[Feat] Support routing replay (#6696)

### What this PR does / why we need it?

[Feat] Support routing replay: capture the expert ids selected for each token during MoE routing on Ascend so they can be returned with the request output.
Same as https://github.com/vllm-project/vllm-ascend/pull/6666; resubmitted because of a doc CI failure.
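
A minimal sketch of how routing replay is expected to be consumed (mirroring the new e2e test in this PR; it assumes `enable_return_routed_experts` is forwarded to the `LLM` entrypoint the same way `VllmRunner` forwards it, and that `routed_experts` is exposed on each completion as in that test):

```python
from vllm import LLM, SamplingParams

# Hypothetical end-to-end usage; the flag and output field names follow the
# e2e test added in this PR, not a documented public API.
llm = LLM(model="Qwen/Qwen3-30B-A3B",
          tensor_parallel_size=2,
          enable_expert_parallel=True,
          enable_return_routed_experts=True)
outputs = llm.generate(["Hello, please introduce yourself."],
                       SamplingParams(max_tokens=5))

# Expert ids captured per generated token and per MoE layer.
print(outputs[0].outputs[0].routed_experts.shape)
```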

### Does this PR introduce _any_ user-facing change?

Yes. With `enable_return_routed_experts` enabled in the model config, request outputs on Ascend now carry the routed experts captured for each token.
### How was this patch tested?

- New e2e test: `tests/e2e/multicard/2-cards/test_qwen3_moe_routing_replay.py`
- New UT covering the patched `RoutedExpertsCapturer.init_buffer`

- vLLM version: v0.15.0
- vLLM main: 9562912cea

---------

Signed-off-by: liyongwen <1310439159@qq.com>
Signed-off-by: Li-Yongwen <63399187+Li-Yongwen@users.noreply.github.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
Author: Li-Yongwen
Date: 2026-02-26 10:22:47 +08:00
Committed by: GitHub
Parent: a9cca0c5c4
Commit: 2870f7c8ad
7 changed files with 190 additions and 0 deletions

View File

@@ -126,6 +126,8 @@ e2e-multicard-2-cards:
     estimated_time: 70
   - name: tests/e2e/multicard/2-cards/test_qwen3_moe.py
     estimated_time: 1050
+  - name: tests/e2e/multicard/2-cards/test_qwen3_moe_routing_replay.py
+    estimated_time: 1050
   - name: tests/e2e/multicard/2-cards/test_single_request_aclgraph.py
     estimated_time: 215
   - name: tests/e2e/multicard/2-cards/test_disaggregated_encoder.py

View File

@@ -0,0 +1,32 @@
import os
from unittest.mock import patch

from vllm import SamplingParams
from vllm.sampling_params import RequestOutputKind

from tests.e2e.conftest import VllmRunner


@patch.dict(os.environ, {"OMP_NUM_THREADS": "1"})
def test_qwen3_moe_routing_replay():
    prompts = [
        "Hello, please introduce yourself.",
    ]
    with VllmRunner(
            "Qwen/Qwen3-30B-A3B",
            tensor_parallel_size=2,
            enable_expert_parallel=True,
            cudagraph_capture_sizes=[1, 2, 4, 8],
            distributed_executor_backend="mp",
            enable_return_routed_experts=True,
    ) as vllm_model:
        sampling_params = SamplingParams(
            max_tokens=5,
            temperature=0.8,
            top_p=0.95,
            output_kind=RequestOutputKind.FINAL_ONLY,
        )
        inputs = vllm_model.get_inputs(prompts=prompts)
        outputs = vllm_model.model.generate(prompts=inputs,
                                            sampling_params=sampling_params)
        # The request must finish normally and return both text and the
        # routed experts captured for the generated tokens.
        assert outputs[0].finished
        assert len(outputs[0].outputs[0].text) > 0
        assert outputs[0].outputs[0].routed_experts.size > 0

View File

@@ -0,0 +1,58 @@
import uuid
from unittest.mock import MagicMock, patch

import torch
from vllm.platforms import current_platform

from tests.ut.base import TestBase
from vllm_ascend.patch.worker.patch_routed_experts_capturer import RoutedExpertsCapturer


class MockVllmConfig:

    def __init__(self):
        self.model_config = MagicMock()
        self.model_config.hf_text_config.num_hidden_layers = 1
        self.model_config.hf_text_config.num_experts_per_tok = 1
        self.parallel_config = MagicMock()
        self.parallel_config.data_parallel_rank = 0
        self.instance_id = uuid.uuid4().hex


class TestPatchRoutedExpertsCapturer(TestBase):

    def setUp(self):
        RoutedExpertsCapturer.create()
        self.capturer = RoutedExpertsCapturer.get_instance()
        self.vllm_config = MockVllmConfig()

    def test_init_buffer(self):
        max_num_batched_tokens = 1
        max_num_kv_tokens = 1
        # Report a non-zero TP rank so init_buffer skips the shared-memory
        # setup and only allocates the device buffer under test.
        with patch(
                target="vllm_ascend.patch.worker.patch_routed_experts_capturer.get_tensor_model_parallel_rank",
                return_value=1,
        ):
            current_platform.device_name = "cpu"
            self.capturer.init_buffer(
                max_num_batched_tokens=max_num_batched_tokens,
                max_num_kv_tokens=max_num_kv_tokens,
                vllm_config=self.vllm_config,
            )
            self.assertEqual(
                self.capturer._device_buffer.shape,
                (
                    max_num_batched_tokens,
                    self.vllm_config.model_config.hf_text_config.num_hidden_layers,
                    self.vllm_config.model_config.hf_text_config.num_experts_per_tok,
                ),
            )
            self.assertEqual(self.capturer._device_buffer.dtype, torch.int32)
            self.assertEqual(self.capturer._device_buffer.device.type,
                             current_platform.device_name)

    def tearDown(self):
        self.capturer.clear_buffer()
        self.capturer.cleanup()

View File

@@ -26,6 +26,7 @@ from vllm.forward_context import get_forward_context
 from vllm.logger import logger
 from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
 from vllm.model_executor.layers.fused_moe.layer import FusedMoE, UnquantizedFusedMoEMethod, get_compressed_expert_map
+from vllm.model_executor.layers.fused_moe.routed_experts_capturer import RoutedExpertsCapturer
 from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE
 from vllm_ascend.utils import vllm_version_is
@@ -122,6 +123,13 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
             e_score_correction_bias=e_score_correction_bias,
             global_num_experts=global_num_experts,
         )
+        if layer.vllm_config.model_config is not None and layer.vllm_config.model_config.enable_return_routed_experts:
+            capturer = RoutedExpertsCapturer.get_instance()
+            if capturer is not None:
+                capturer.capture(
+                    layer_id=layer.layer_id,
+                    topk_ids=topk_ids,
+                )
         if zero_expert_num > 0 and zero_expert_type is not None:
             topk_ids, topk_weights, zero_expert_result = zero_experts_compute(
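
For context, a simplified sketch of what the capture step conceptually records (illustrative only; `ToyRoutedExpertsBuffer` is not the real `RoutedExpertsCapturer`, whose buffer shape follows the patched `init_buffer` below):

```python
import torch


class ToyRoutedExpertsBuffer:
    """Illustrative stand-in: one int32 slot per token, layer and top-k pick."""

    def __init__(self, max_tokens: int, num_layers: int, top_k: int):
        self.buf = torch.zeros((max_tokens, num_layers, top_k),
                               dtype=torch.int32)

    def capture(self, layer_id: int, topk_ids: torch.Tensor) -> None:
        # topk_ids: [num_tokens, top_k] expert ids chosen by one MoE layer.
        num_tokens = topk_ids.shape[0]
        self.buf[:num_tokens, layer_id] = topk_ids.to(torch.int32)


buf = ToyRoutedExpertsBuffer(max_tokens=8, num_layers=2, top_k=2)
buf.capture(layer_id=0, topk_ids=torch.tensor([[3, 7], [1, 4]]))
print(buf.buf[:2, 0])  # expert ids recorded for the first two tokens
```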

View File

@@ -34,5 +34,6 @@ import vllm_ascend.patch.worker.patch_qwen3_next # noqa
 import vllm_ascend.patch.worker.patch_v2_eagle  # noqa
 import vllm_ascend.patch.worker.patch_v2_uva  # noqa
 import vllm_ascend.patch.worker.patch_huanyuan_vl  # noqa
+import vllm_ascend.patch.worker.patch_routed_experts_capturer  # noqa
 import vllm_ascend.patch.worker.patch_npugraph_ex_triton  # noqa
 import vllm_ascend.patch.worker.patch_kimi_k25  # noqa

View File

@@ -0,0 +1,68 @@
import numpy as np
import torch
from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_rank
from vllm.model_executor.layers.fused_moe.routed_experts_capturer import (
    _BUFFER_PREFIX,
    _LOCK_FILE_PREFIX,
    RoutedExpertsCapturer,
    _create_or_attach_shared_memory,
    logger,
)
from vllm.platforms import current_platform


def init_buffer(
    self,
    max_num_batched_tokens: int,
    max_num_kv_tokens: int,
    vllm_config: VllmConfig,
) -> None:
    """Initialize the device buffer and, on TP rank 0, the shared-memory buffer.

    Args:
        max_num_batched_tokens: Maximum number of tokens in a batch.
        max_num_kv_tokens: Maximum number of KV tokens for shared memory.
        vllm_config: vLLM configuration containing layer and expert info.
    """
    if self._device_buffer is not None:
        raise RuntimeError("Device buffer has already been initialized")

    hf_config = vllm_config.model_config.hf_text_config
    num_layers = hf_config.num_hidden_layers
    num_experts_per_tok = hf_config.num_experts_per_tok

    # Initialize the device buffer on the current platform's device rather
    # than a hard-coded "cuda" device.
    self._device_buffer = torch.zeros(
        (max_num_batched_tokens, num_layers, num_experts_per_tok),
        dtype=torch.int32,
        device=current_platform.device_name,
    )
    self.dp_rank = vllm_config.parallel_config.data_parallel_rank

    # Only tensor-parallel rank 0 owns the shared-memory buffer.
    if get_tensor_model_parallel_rank() != 0:
        return

    # Initialize shared memory.
    shape = (max_num_kv_tokens, num_layers, num_experts_per_tok)
    buffer_size = int(np.prod(shape)) * np.dtype(np.int32).itemsize
    instance_id = vllm_config.instance_id
    self._lock_file = f"{_LOCK_FILE_PREFIX}_{instance_id}_{self.dp_rank}.lock"
    shm_name = f"{_BUFFER_PREFIX}_{instance_id}_{self.dp_rank}"
    self._shm = _create_or_attach_shared_memory(shm_name, buffer_size,
                                                self._lock_file)
    self._host_buffer_view = np.ndarray(shape, dtype=np.int32,
                                        buffer=self._shm.buf)
    self._host_buffer_view.fill(0)
    logger.debug("Created shared memory buffer '%s' with shape %s",
                 shm_name, shape)


# Patch _device_buffer's initialization (device="cuda" ->
# device=current_platform.device_name).
# TODO: Remove this patch once https://github.com/vllm-project/vllm/pull/34336
# is merged.
RoutedExpertsCapturer.init_buffer = init_buffer
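
For reference, the only behavioural difference is device placement: on vllm-ascend `current_platform.device_name` resolves to the NPU device rather than `"cuda"` (a minimal sketch of the allocation the patch enables; the shape here is a placeholder):

```python
import torch
from vllm.platforms import current_platform

# On vllm-ascend this is the NPU device string; on stock CUDA builds it is
# "cuda". Hard-coding "cuda" here would fail on Ascend hardware.
buf = torch.zeros((8, 2, 2), dtype=torch.int32,
                  device=current_platform.device_name)
print(buf.device)
```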

View File

@@ -129,6 +129,7 @@ from vllm_ascend.ascend_forward_context import ( # isort: skip
     set_mc2_mask,
     set_mc2_tokens_capacity,
 )
+from vllm.model_executor.layers.fused_moe.routed_experts_capturer import RoutedExpertsCapturer
 if TYPE_CHECKING:
     import xgrammar as xgr  # type: ignore[import-untyped]
@@ -373,6 +374,7 @@ class NPUModelRunner(GPUModelRunner):
         self.intermediate_tensors: IntermediateTensors | None = None
         self.reorder_batch_threshold: int | None = None
         self.long_seq_metadata = None
+        self.cpu_slot_mapping = None
     @property
     def use_cp(self) -> bool:
@@ -1050,6 +1052,12 @@ class NPUModelRunner(GPUModelRunner):
         scheduler_output: "SchedulerOutput",
         intermediate_tensors: IntermediateTensors | None = None,
     ) -> ModelRunnerOutput | IntermediateTensors | None:
+        if self.vllm_config.model_config.enable_return_routed_experts:
+            capturer = RoutedExpertsCapturer.get_instance()
+            if capturer is not None:
+                capturer.clear_buffer()
+            else:
+                logger.warning("RoutedExpertsCapturer is not initialized.")
         if self.execute_model_state is not None:
             raise RuntimeError("State error: sample_tokens() must be called after execute_model() returns None.")
         # self._draft_token_ids is None when `input_fits_in_drafter=False`
@@ -1428,6 +1436,14 @@ class NPUModelRunner(GPUModelRunner):
         if has_kv_transfer_group():
             get_kv_transfer_group().clear_connector_metadata()
+        if self.model_config.enable_return_routed_experts:
+            capturer = RoutedExpertsCapturer.get_instance()
+            if capturer is not None:
+                capturer.save_captured_experts(indices=self.cpu_slot_mapping)
+            else:
+                logger.warning("RoutedExpertsCapturer is not initialized.")
         model_runner_output = ModelRunnerOutput(
             req_ids=req_ids_output_copy,
             req_id_to_index=req_id_to_index_output_copy,
@@ -1902,6 +1918,8 @@ class NPUModelRunner(GPUModelRunner):
                 num_tokens_padded,
                 slot_mapping,
             )
+            if self.model_config.enable_return_routed_experts and kv_cache_gid == 0:
+                self.cpu_slot_mapping = slot_mapping.cpu().numpy()
             return blk_table_tensor, slot_mapping
         block_table_gid_0, slot_mapping_gid_0 = _get_block_table_and_slot_mapping(0)
@@ -2364,6 +2382,9 @@ class NPUModelRunner(GPUModelRunner):
         if has_kv_transfer_group():
             get_kv_transfer_group().register_kv_caches(kv_caches)
+        if self.model_config.enable_return_routed_experts:
+            self.init_routed_experts_capturer()
     def _align_memory(self, tensor: torch.Tensor, alignment: int) -> torch.Tensor:
         data_ptr = tensor.data_ptr()
         aligned_addr = (data_ptr + alignment - 1) // alignment * alignment
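
Taken together, the runner-side hooks above give routing replay a simple per-step lifecycle. The sketch below is schematic (method names follow the diffs in this PR, but the surrounding control flow of `NPUModelRunner` is heavily simplified):

```python
from vllm.model_executor.layers.fused_moe.routed_experts_capturer import \
    RoutedExpertsCapturer


def routing_replay_step(runner, run_forward_pass):
    """Schematic per-step flow; not the actual NPUModelRunner code."""
    capturer = RoutedExpertsCapturer.get_instance()

    # 1. Before the forward pass: drop expert ids left over from the
    #    previous step.
    if capturer is not None:
        capturer.clear_buffer()

    # 2. Forward pass: every Ascend fused-MoE layer calls
    #    capturer.capture(layer_id=..., topk_ids=...) right after
    #    select_experts picks its top-k experts.
    run_forward_pass()

    # 3. After the forward pass: persist the captured ids for the tokens
    #    that actually ran, keyed by the CPU copy of the slot mapping
    #    recorded for KV-cache group 0.
    if capturer is not None:
        capturer.save_captured_experts(indices=runner.cpu_slot_mapping)
```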