[Refactor] MLP weight prefetch to consistency with MoE Model's prefetching in terms of code and usage (#6442)

### What this PR does / why we need it?
Refactor MLP weight prefetch to consistency with MoE Model's prefetching
in terms of code and usage.
Environments VLLM_ASCEND_ENABLE_PREFETCH_MLP,
VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE and
VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE is removed, usage as following:

--additional-config '{"weight_prefetch_config": { "enabled": true,
"prefetch_ratio": {"mlp": { "gate_up": 1.0, "down": 1.0} }}}'

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.14.1
- vLLM main:
dc917cceb8

---------

Signed-off-by: leo-pony <nengjunma@outlook.com>
This commit is contained in:
Nengjun Ma
2026-02-04 09:08:18 +08:00
committed by GitHub
parent fa56abea9f
commit 78fad4e348
18 changed files with 250 additions and 171 deletions

View File

@@ -222,7 +222,7 @@ def test_qwen3_dense_fc1_tp2(model):
@pytest.mark.parametrize("model", QWEN_DENSE_MODELS)
@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_PREFETCH_MLP": "1"})
@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM1": "1"})
def test_qwen3_dense_prefetch_mlp_weight_tp2(model):
example_prompts = [
"Hello, my name is",
@@ -236,6 +236,7 @@ def test_qwen3_dense_prefetch_mlp_weight_tp2(model):
tensor_parallel_size=2,
cudagraph_capture_sizes=[1, 2, 4, 8],
quantization="ascend",
additional_config={"weight_prefetch_config": {"enabled": True}},
) as vllm_model:
vllm_model.generate_greedy(example_prompts, max_tokens)

View File

@@ -57,7 +57,6 @@ async def test_models(model: str) -> None:
env_dict = {
"TASK_QUEUE_ENABLE": "1",
"HCCL_OP_EXPANSION_MODE": "AIV",
"VLLM_ASCEND_ENABLE_PREFETCH_MLP": "1",
}
server_args = [
"--async-scheduling",
@@ -74,7 +73,7 @@ async def test_models(model: str) -> None:
"--compilation-config",
'{"cudagraph_mode": "FULL_DECODE_ONLY"}',
"--additional-config",
'{"pa_shape_list":[48,64,72,80]}',
'{"pa_shape_list":[48,64,72,80],"weight_prefetch_config":{"enabled":true}}',
"--block-size",
"128",
"--trust-remote-code",

View File

@@ -83,7 +83,6 @@ async def test_models(model: str, mode: str, tp_size: int) -> None:
"TASK_QUEUE_ENABLE": "1",
"HCCL_OP_EXPANSION_MODE": "AIV",
"VLLM_ASCEND_ENABLE_FLASHCOMM": "1",
"VLLM_ASCEND_ENABLE_PREFETCH_MLP": "1"
}
compilation_config = {
"cudagraph_mode":
@@ -98,7 +97,8 @@ async def test_models(model: str, mode: str, tp_size: int) -> None:
str(port), "--max-model-len", "40960", "--max-num-batched-tokens",
"40960", "--block-size", "128", "--trust-remote-code",
"--reasoning-parser", "qwen3", "--gpu-memory-utilization", "0.9",
"--async-scheduling"
"--async-scheduling", "--additional-config",
'{"weight_prefetch_config":{"enabled":true}}',
]
if mode == "single":
server_args.append("--enforce-eager")

View File

@@ -72,7 +72,6 @@ async def test_models(model: str, tp_size: int) -> None:
"OMP_PROC_BIND": "false",
"VLLM_ASCEND_ENABLE_TOPK_OPTIMIZE": "1",
"VLLM_ASCEND_ENABLE_FLASHCOMM": "1",
"VLLM_ASCEND_ENABLE_PREFETCH_MLP": "1"
}
server_args = [
"--quantization", "ascend", "--tensor-parallel-size",
@@ -82,7 +81,8 @@ async def test_models(model: str, tp_size: int) -> None:
"0.9", "--block-size", "128", "--max-num-seqs", "256",
"--enforce-eager", "--max-model-len", "35840",
"--max-num-batched-tokens", "35840", "--additional-config",
'{"enable_weight_nz_layout":true}', "--compilation-config",
'{"enable_weight_nz_layout":true, "weight_prefetch_config":{"enabled": true}}',
"--compilation-config",
'{"cudagraph_mode":"FULL_DECODE_ONLY", "cudagraph_capture_sizes":[1,8,24,48,60]}'
]
with RemoteOpenAIServer(model,

View File

@@ -75,8 +75,7 @@ async def test_models(model: str, mode: str, tp_size: int) -> None:
"OMP_PROC_BIND": "false",
"HCCL_OP_EXPANSION_MODE": "AIV",
"VLLM_ASCEND_ENABLE_FLASHCOMM": "1",
"VLLM_ASCEND_ENABLE_DEBSE_OPTIMIZE": "1",
"VLLM_ASCEND_ENABLE_PREFETCH_MLP": "1"
"VLLM_ASCEND_ENABLE_DEBSE_OPTIMIZE": "1"
}
server_args = [
"--tensor-parallel-size",
@@ -86,7 +85,7 @@ async def test_models(model: str, mode: str, tp_size: int) -> None:
"--gpu-memory-utilization", "0.9", "--compilation_config",
'{"cudagraph_mode":"FULL_DECODE_ONLY", "cudagraph_capture_sizes": [1, 8, 24, 48, 60]}',
"--reasoning-parser", "deepseek_r1", "--distributed_executor_backend",
"mp"
"mp", "--additional-config", '{"weight_prefetch_config":{"enabled":true}}'
]
if mode == "single":
server_args.remove("--compilation_config")

View File

@@ -54,11 +54,7 @@ def test_QuickGELU_forward(mock_gelu, dummy_tensor, default_vllm_config):
@pytest.mark.skipif(is_310p_hw(), reason="non_310P device unittest case.")
@patch("torch_npu.npu_swiglu", side_effect=lambda x: x + 1)
@patch("torch.ops.vllm.maybe_wait_prefetch_done", side_effect=lambda x: None)
@patch("torch.ops.vllm.maybe_prefetch_mlp_down_proj", side_effect=lambda x: None)
def test_SiluAndMul_forward(
mock_maybe_prefetch_mlp_down_proj,
mock_maybe_wait_prefetch_done,
mock_swiglu,
dummy_tensor,
default_vllm_config,
@@ -67,15 +63,9 @@ def test_SiluAndMul_forward(
out = layer.forward(dummy_tensor)
expected_arg = dummy_tensor
# assert mock_maybe_prefetch_mlp_down_proj.call_count == 1
mock_maybe_prefetch_mlp_down_proj.assert_called_once()
# assert mock_swiglu.call_count == 1
mock_swiglu.assert_called_once()
# assert mock_maybe_wait_prefetch_done.call_count == 1
mock_maybe_wait_prefetch_done.assert_called_once()
actual_arg = mock_swiglu.call_args[0][0]
assert torch.allclose(actual_arg, expected_arg), "npu_swiglu called with unexpected input"
@@ -85,11 +75,7 @@ def test_SiluAndMul_forward(
@pytest.mark.skipif(not is_310p_hw(), reason="310P device unittest case.")
@patch("torch.nn.functional.silu", side_effect=lambda x: x + 1)
@patch("torch.ops.vllm.maybe_wait_prefetch_done", side_effect=lambda x: None)
@patch("torch.ops.vllm.maybe_prefetch_mlp_down_proj", side_effect=lambda x: None)
def test_SiluAndMul_forward_310p(
mock_maybe_prefetch_mlp_down_proj,
mock_maybe_wait_prefetch_done,
mock_silu,
dummy_tensor,
default_vllm_config,
@@ -99,15 +85,9 @@ def test_SiluAndMul_forward_310p(
h = dummy_tensor.shape[-1] // 2
expected_arg = dummy_tensor[..., :h]
# assert mock_maybe_prefetch_mlp_down_proj.call_count == 1
mock_maybe_prefetch_mlp_down_proj.assert_called_once()
# assert mock_silu.call_count == 1
mock_silu.assert_called_once()
# assert mock_maybe_wait_prefetch_done.call_count == 1
mock_maybe_wait_prefetch_done.assert_called_once()
actual_arg = mock_silu.call_args[0][0]
assert torch.allclose(actual_arg, expected_arg), "swiglu called with unexpected input"