[main] mlp weight prefetch in Qwen Dense Models (#2816)

### What this PR does / why we need it?
This PR prefetchs the weight of mlp layers in Qwen Dense Models to
optimize the performance in Decode phase mainly.

### Does this PR introduce _any_ user-facing change?
 No.

### How was this patch tested?
CI passed with new added/existing test.

- vLLM version: main
- vLLM main:
a1213fae5f

Signed-off-by: rjg-lyh <1318825571@qq.com>
Co-authored-by: Shuming19 <313093131@qq.com>
This commit is contained in:
rjg-lyh
2025-09-11 21:20:09 +08:00
committed by GitHub
parent c3c2221503
commit 0005479b9c
17 changed files with 313 additions and 24 deletions

View File

@@ -38,7 +38,12 @@ def test_QuickGELU_forward(mock_gelu, dummy_tensor):
@pytest.mark.parametrize("is_310p_return", [True, False])
@patch("torch_npu.npu_swiglu", side_effect=lambda x: x + 1)
def test_SiluAndMul_forward(mock_swiglu, is_310p_return, dummy_tensor):
@patch("torch.ops.vllm.maybe_wait_prefetch_done", side_effect=lambda x: None)
@patch("torch.ops.vllm.maybe_prefetch_mlp_down_proj",
side_effect=lambda x: None)
def test_SiluAndMul_forward(mock_maybe_prefetch_mlp_down_proj,
mock_maybe_wait_prefetch_done, mock_swiglu,
is_310p_return, dummy_tensor):
with patch("vllm_ascend.utils.is_310p", return_value=is_310p_return):
layer = SiluAndMul()
@@ -49,9 +54,15 @@ def test_SiluAndMul_forward(mock_swiglu, is_310p_return, dummy_tensor):
else:
expected_arg = dummy_tensor
# assert mock_maybe_prefetch_mlp_down_proj.call_count == 1
mock_maybe_prefetch_mlp_down_proj.assert_called_once()
# assert mock_swiglu.call_count == 1
mock_swiglu.assert_called_once()
# assert mock_maybe_wait_prefetch_done.call_count == 1
mock_maybe_wait_prefetch_done.assert_called_once()
actual_arg = mock_swiglu.call_args[0][0]
assert torch.allclose(
actual_arg,

View File

@@ -30,9 +30,11 @@ def mock_add_rms_norm(x, residual, weight, eps):
[None, torch.randn(4, 8, dtype=torch.float32)])
@patch("torch_npu.npu_rms_norm", side_effect=mock_rms_norm)
@patch("torch_npu.npu_add_rms_norm", side_effect=mock_add_rms_norm)
@patch("torch.ops.vllm.maybe_wait_prefetch_done", side_effect=lambda x: None)
@patch("torch.ops.vllm.maybe_chunk_residual",
side_effect=mock_maybe_chunk_residual)
def test_RMSNorm_forward(mock_maybe_chunk_residual, mock_add_rmsnorm,
def test_RMSNorm_forward(mock_maybe_chunk_residual,
mock_maybe_wait_prefetch_done, mock_add_rmsnorm,
mock_rmsnorm, is_310p_return, residual, dummy_tensor):
with patch("vllm_ascend.utils.is_310p", return_value=is_310p_return):
@@ -45,13 +47,17 @@ def test_RMSNorm_forward(mock_maybe_chunk_residual, mock_add_rmsnorm,
expected_out_x = expected_arg_x + 1
expected_out_residual = expected_arg_x.to(residual.dtype)
mock_maybe_chunk_residual.assert_called_once()
mock_rmsnorm.assert_called_once()
mock_maybe_wait_prefetch_done.assert_called_once()
assert torch.allclose(out_x, expected_out_x)
assert torch.allclose(out_residual, expected_out_residual)
else:
expected_out_x = 2 * dummy_tensor
expected_out_residual = 2 * residual
mock_maybe_chunk_residual.assert_called_once()
mock_add_rmsnorm.assert_called_once()
mock_maybe_wait_prefetch_done.assert_called_once()
assert torch.allclose(out_x, expected_out_x)
assert torch.allclose(out_residual, expected_out_residual)
else:
@@ -64,9 +70,11 @@ def test_RMSNorm_forward(mock_maybe_chunk_residual, mock_add_rmsnorm,
@patch("vllm_ascend.utils.is_310p", return_value=False)
@patch("torch_npu.npu_add_rms_norm", side_effect=mock_add_rms_norm)
@patch("torch.ops.vllm.maybe_wait_prefetch_done", side_effect=lambda x: None)
@patch("torch.ops.vllm.maybe_chunk_residual",
side_effect=mock_maybe_chunk_residual)
def test_RMSNorm_forward_with_flashcomm_v1(mock_maybe_chunk_residual,
mock_maybe_wait_prefetch_done,
mock_add_rms_norm, mock_is310p):
x = torch.randn(4, 512, dtype=torch.bfloat16)
residual = torch.randn(16, 512, dtype=torch.bfloat16)
@@ -79,6 +87,7 @@ def test_RMSNorm_forward_with_flashcomm_v1(mock_maybe_chunk_residual,
mock_maybe_chunk_residual.assert_called_once()
mock_add_rms_norm.assert_called_once()
mock_maybe_wait_prefetch_done.assert_called_once()
assert out_residual.size(0) == 4
assert torch.allclose(out_x, expected_out_x)
assert torch.allclose(out_residual, expected_out_residual)

View File

@@ -275,7 +275,12 @@ def test_torchair_deepseek_v2_mla_attention(mock_rms_norm, mock_distributed,
@patch("torch_npu.npu_add_rms_norm")
@patch("torch_npu.npu_rms_norm")
def test_torchair_deepseek_v2_decoder_layer(mock_rms_norm, mock_add_norm,
@patch("torch.ops.vllm.maybe_wait_prefetch_done", side_effect=lambda x: None)
@patch("torch.ops.vllm.maybe_chunk_residual",
side_effect=lambda x, residual: residual)
def test_torchair_deepseek_v2_decoder_layer(mock_maybe_chunk_residual,
mock_maybe_wait_prefetch_done,
mock_rms_norm, mock_add_norm,
mock_distributed, base_config,
vllm_config):
mock_rms_norm.return_value = (torch.randn(2, 128), torch.randn(2, 128))