[main] mlp weight prefetch in Qwen Dense Models (#2816)
### What this PR does / why we need it?
This PR prefetches the weights of the MLP layers in Qwen dense models, mainly to optimize performance in the decode phase.
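At a high level, the forward pass can kick off an asynchronous prefetch of the `down_proj` weight while the activation kernel runs, then wait for it just before the down projection. Below is a minimal sketch of that pattern with plain-Python stand-ins for the `torch.ops.vllm.maybe_prefetch_mlp_down_proj` / `torch.ops.vllm.maybe_wait_prefetch_done` hooks this PR wires up; the layer shapes and helper bodies are illustrative, not the actual vllm-ascend implementation:

```python
import torch
import torch.nn.functional as F


def maybe_prefetch_mlp_down_proj(gate_up_out: torch.Tensor) -> None:
    """Stand-in: would start copying the down_proj weight toward the NPU."""


def maybe_wait_prefetch_done(x: torch.Tensor) -> None:
    """Stand-in: would block until the prefetch stream has finished."""


class QwenMLPSketch(torch.nn.Module):
    def __init__(self, hidden: int, intermediate: int):
        super().__init__()
        self.gate_up_proj = torch.nn.Linear(hidden, 2 * intermediate, bias=False)
        self.down_proj = torch.nn.Linear(intermediate, hidden, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate_up = self.gate_up_proj(x)
        # Overlap: start the weight prefetch before the activation runs.
        maybe_prefetch_mlp_down_proj(gate_up)
        gate, up = gate_up.chunk(2, dim=-1)
        act = F.silu(gate) * up  # SwiGLU, as in SiluAndMul
        # Ensure the weight is resident before the down projection.
        maybe_wait_prefetch_done(act)
        return self.down_proj(act)


out = QwenMLPSketch(8, 16)(torch.randn(2, 8))
```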
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
CI passed with newly added and existing tests.
- vLLM version: main
- vLLM main: a1213fae5f
Signed-off-by: rjg-lyh <1318825571@qq.com>
Co-authored-by: Shuming19 <313093131@qq.com>
@@ -38,7 +38,12 @@ def test_QuickGELU_forward(mock_gelu, dummy_tensor):
 @pytest.mark.parametrize("is_310p_return", [True, False])
 @patch("torch_npu.npu_swiglu", side_effect=lambda x: x + 1)
-def test_SiluAndMul_forward(mock_swiglu, is_310p_return, dummy_tensor):
+@patch("torch.ops.vllm.maybe_wait_prefetch_done", side_effect=lambda x: None)
+@patch("torch.ops.vllm.maybe_prefetch_mlp_down_proj",
+       side_effect=lambda x: None)
+def test_SiluAndMul_forward(mock_maybe_prefetch_mlp_down_proj,
+                            mock_maybe_wait_prefetch_done, mock_swiglu,
+                            is_310p_return, dummy_tensor):
 
     with patch("vllm_ascend.utils.is_310p", return_value=is_310p_return):
         layer = SiluAndMul()
@@ -49,9 +54,15 @@ def test_SiluAndMul_forward(mock_swiglu, is_310p_return, dummy_tensor):
         else:
             expected_arg = dummy_tensor
 
+        # assert mock_maybe_prefetch_mlp_down_proj.call_count == 1
+        mock_maybe_prefetch_mlp_down_proj.assert_called_once()
+
         # assert mock_swiglu.call_count == 1
         mock_swiglu.assert_called_once()
+
+        # assert mock_maybe_wait_prefetch_done.call_count == 1
+        mock_maybe_wait_prefetch_done.assert_called_once()
 
         actual_arg = mock_swiglu.call_args[0][0]
         assert torch.allclose(
             actual_arg,
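A note on the mock plumbing above: `@patch` decorators apply bottom-up, so the bottom-most patch supplies the first mock parameter (hence `mock_maybe_prefetch_mlp_down_proj` comes first in the test signature), and `assert_called_once()` is the idiomatic equivalent of the commented-out `call_count == 1` checks. A self-contained illustration, with no vllm-ascend imports:

```python
from unittest.mock import MagicMock

# Same side_effect style as the patches above: the stub both records the
# call and returns a value the test can check.
mock_swiglu = MagicMock(side_effect=lambda x: x + 1)

out = mock_swiglu(3)
assert out == 4
assert mock_swiglu.call_count == 1       # the commented-out style
mock_swiglu.assert_called_once()         # the preferred equivalent
assert mock_swiglu.call_args[0][0] == 3  # first positional arg, as used above
```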
@@ -30,9 +30,11 @@ def mock_add_rms_norm(x, residual, weight, eps):
                          [None, torch.randn(4, 8, dtype=torch.float32)])
 @patch("torch_npu.npu_rms_norm", side_effect=mock_rms_norm)
 @patch("torch_npu.npu_add_rms_norm", side_effect=mock_add_rms_norm)
+@patch("torch.ops.vllm.maybe_wait_prefetch_done", side_effect=lambda x: None)
 @patch("torch.ops.vllm.maybe_chunk_residual",
        side_effect=mock_maybe_chunk_residual)
-def test_RMSNorm_forward(mock_maybe_chunk_residual, mock_add_rmsnorm,
+def test_RMSNorm_forward(mock_maybe_chunk_residual,
+                         mock_maybe_wait_prefetch_done, mock_add_rmsnorm,
                          mock_rmsnorm, is_310p_return, residual, dummy_tensor):
 
     with patch("vllm_ascend.utils.is_310p", return_value=is_310p_return):
@@ -45,13 +47,17 @@ def test_RMSNorm_forward(mock_maybe_chunk_residual, mock_add_rmsnorm,
             expected_out_x = expected_arg_x + 1
             expected_out_residual = expected_arg_x.to(residual.dtype)
 
+            mock_maybe_chunk_residual.assert_called_once()
             mock_rmsnorm.assert_called_once()
+            mock_maybe_wait_prefetch_done.assert_called_once()
             assert torch.allclose(out_x, expected_out_x)
             assert torch.allclose(out_residual, expected_out_residual)
         else:
             expected_out_x = 2 * dummy_tensor
             expected_out_residual = 2 * residual
+            mock_maybe_chunk_residual.assert_called_once()
             mock_add_rmsnorm.assert_called_once()
+            mock_maybe_wait_prefetch_done.assert_called_once()
             assert torch.allclose(out_x, expected_out_x)
             assert torch.allclose(out_residual, expected_out_residual)
     else:
@@ -64,9 +70,11 @@ def test_RMSNorm_forward(mock_maybe_chunk_residual, mock_add_rmsnorm,
 
 @patch("vllm_ascend.utils.is_310p", return_value=False)
 @patch("torch_npu.npu_add_rms_norm", side_effect=mock_add_rms_norm)
+@patch("torch.ops.vllm.maybe_wait_prefetch_done", side_effect=lambda x: None)
 @patch("torch.ops.vllm.maybe_chunk_residual",
        side_effect=mock_maybe_chunk_residual)
-def test_RMSNorm_forward_with_flashcomm_v1(mock_maybe_chunk_residual,
-                                           mock_add_rms_norm, mock_is310p):
+def test_RMSNorm_forward_with_flashcomm_v1(mock_maybe_chunk_residual,
+                                           mock_maybe_wait_prefetch_done,
+                                           mock_add_rms_norm, mock_is310p):
     x = torch.randn(4, 512, dtype=torch.bfloat16)
     residual = torch.randn(16, 512, dtype=torch.bfloat16)
@@ -79,6 +87,7 @@ def test_RMSNorm_forward_with_flashcomm_v1(mock_maybe_chunk_residual,
 
     mock_maybe_chunk_residual.assert_called_once()
     mock_add_rms_norm.assert_called_once()
+    mock_maybe_wait_prefetch_done.assert_called_once()
     assert out_residual.size(0) == 4
     assert torch.allclose(out_x, expected_out_x)
     assert torch.allclose(out_residual, expected_out_residual)
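For context on the flashcomm_v1 assertion `out_residual.size(0) == 4`: with `x` of shape `(4, 512)` and `residual` of shape `(16, 512)`, the mocked `maybe_chunk_residual` evidently trims the residual to match the already-chunked hidden states. A hedged sketch of such a helper, inferred only from the shapes in this test (the real `mock_maybe_chunk_residual` body lives in vllm-ascend):

```python
import torch


def chunk_residual_sketch(x: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
    # Inferred behavior: when hidden states are already chunked (e.g. for
    # flashcomm_v1 sequence parallelism), keep only the matching slice of
    # the residual; otherwise pass it through unchanged.
    if x.size(0) != residual.size(0):
        return residual[: x.size(0)]
    return residual


x = torch.randn(4, 512, dtype=torch.bfloat16)
residual = torch.randn(16, 512, dtype=torch.bfloat16)
assert chunk_residual_sketch(x, residual).size(0) == 4  # matches the test
```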
@@ -275,7 +275,12 @@ def test_torchair_deepseek_v2_mla_attention(mock_rms_norm, mock_distributed,
 
 @patch("torch_npu.npu_add_rms_norm")
 @patch("torch_npu.npu_rms_norm")
-def test_torchair_deepseek_v2_decoder_layer(mock_rms_norm, mock_add_norm,
-                                            mock_distributed, base_config,
-                                            vllm_config):
+@patch("torch.ops.vllm.maybe_wait_prefetch_done", side_effect=lambda x: None)
+@patch("torch.ops.vllm.maybe_chunk_residual",
+       side_effect=lambda x, residual: residual)
+def test_torchair_deepseek_v2_decoder_layer(mock_maybe_chunk_residual,
+                                            mock_maybe_wait_prefetch_done,
+                                            mock_rms_norm, mock_add_norm,
+                                            mock_distributed, base_config,
+                                            vllm_config):
     mock_rms_norm.return_value = (torch.randn(2, 128), torch.randn(2, 128))
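The `torch.ops.vllm.*` hooks patched throughout these tests are PyTorch custom ops; their signatures match the stub lambdas (`maybe_wait_prefetch_done(Tensor) -> ()`, `maybe_chunk_residual(Tensor, Tensor) -> Tensor`). For readers unfamiliar with that mechanism, here is a minimal registration into a scratch namespace using plain `torch.library`; vLLM registers its ops through its own helpers, so both the namespace and the op here are illustrative only:

```python
import torch

lib = torch.library.Library("prefetch_demo", "DEF")  # scratch namespace
lib.define("maybe_wait_prefetch_done(Tensor x) -> ()")
lib.impl("maybe_wait_prefetch_done", lambda x: None, "CompositeExplicitAutograd")

# Callable (and therefore patchable in tests) like any other torch op:
torch.ops.prefetch_demo.maybe_wait_prefetch_done(torch.randn(2, 2))
```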