[main] mlp weight prefetch in Qwen Dense Models (#2816)
### What this PR does / why we need it?
This PR prefetches the weights of the MLP layers in Qwen Dense Models,
mainly to optimize performance in the decode phase.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
CI passed with newly added and existing tests.
- vLLM version: main
- vLLM main: a1213fae5f
Signed-off-by: rjg-lyh <1318825571@qq.com>
Co-authored-by: Shuming19 <313093131@qq.com>
This commit is contained in:
@@ -31,7 +31,9 @@ from tests.e2e.conftest import VllmRunner
|
||||
|
||||
os.environ["PYTORCH_NPU_ALLOC_CONF"] = "max_split_size_mb:256"
|
||||
|
||||
QWEN_DENSE_MODELS = ["Qwen/QwQ-32B", "Qwen/Qwen-32B"]
|
||||
QWEN_DENSE_MODELS = [
|
||||
"vllm-ascend/Qwen3-8B-W8A8", "vllm-ascend/Qwen2.5-0.5B-Instruct-W8A8"
|
||||
]
|
||||
|
||||
|
||||
def test_models_distributed_QwQ():
|
||||
@@ -170,6 +172,29 @@ def test_models_distributed_Qwen_Dense_with_flashcomm_v1(model, enforce_eager):
|
||||
max_model_len=8192,
|
||||
enforce_eager=enforce_eager,
|
||||
dtype="auto",
|
||||
tensor_parallel_size=4,
|
||||
tensor_parallel_size=2,
|
||||
quantization="ascend",
|
||||
) as vllm_model:
|
||||
vllm_model.generate_greedy(example_prompts, max_tokens)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("enforce_eager", [True, False])
|
||||
@pytest.mark.parametrize("model", QWEN_DENSE_MODELS)
|
||||
@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE": "1"})
|
||||
@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_PREFETCH_MLP": "1"})
|
||||
def test_models_distributed_Qwen_Dense_with_prefetch_mlp_weight(
|
||||
model, enforce_eager):
|
||||
example_prompts = [
|
||||
"Hello, my name is",
|
||||
]
|
||||
max_tokens = 5
|
||||
|
||||
with VllmRunner(
|
||||
snapshot_download(model),
|
||||
max_model_len=8192,
|
||||
enforce_eager=enforce_eager,
|
||||
dtype="auto",
|
||||
tensor_parallel_size=2,
|
||||
quantization="ascend",
|
||||
) as vllm_model:
|
||||
vllm_model.generate_greedy(example_prompts, max_tokens)
|
||||
|
||||
@@ -38,7 +38,12 @@ def test_QuickGELU_forward(mock_gelu, dummy_tensor):
|
||||
|
||||
@pytest.mark.parametrize("is_310p_return", [True, False])
|
||||
@patch("torch_npu.npu_swiglu", side_effect=lambda x: x + 1)
|
||||
def test_SiluAndMul_forward(mock_swiglu, is_310p_return, dummy_tensor):
|
||||
@patch("torch.ops.vllm.maybe_wait_prefetch_done", side_effect=lambda x: None)
|
||||
@patch("torch.ops.vllm.maybe_prefetch_mlp_down_proj",
|
||||
side_effect=lambda x: None)
|
||||
def test_SiluAndMul_forward(mock_maybe_prefetch_mlp_down_proj,
|
||||
mock_maybe_wait_prefetch_done, mock_swiglu,
|
||||
is_310p_return, dummy_tensor):
|
||||
|
||||
with patch("vllm_ascend.utils.is_310p", return_value=is_310p_return):
|
||||
layer = SiluAndMul()
|
||||
@@ -49,9 +54,15 @@ def test_SiluAndMul_forward(mock_swiglu, is_310p_return, dummy_tensor):
|
||||
else:
|
||||
expected_arg = dummy_tensor
|
||||
|
||||
# assert mock_maybe_prefetch_mlp_down_proj.call_count == 1
|
||||
mock_maybe_prefetch_mlp_down_proj.assert_called_once()
|
||||
|
||||
# assert mock_swiglu.call_count == 1
|
||||
mock_swiglu.assert_called_once()
|
||||
|
||||
# assert mock_maybe_wait_prefetch_done.call_count == 1
|
||||
mock_maybe_wait_prefetch_done.assert_called_once()
|
||||
|
||||
actual_arg = mock_swiglu.call_args[0][0]
|
||||
assert torch.allclose(
|
||||
actual_arg,
|
||||
|
||||
@@ -30,9 +30,11 @@ def mock_add_rms_norm(x, residual, weight, eps):
|
||||
[None, torch.randn(4, 8, dtype=torch.float32)])
|
||||
@patch("torch_npu.npu_rms_norm", side_effect=mock_rms_norm)
|
||||
@patch("torch_npu.npu_add_rms_norm", side_effect=mock_add_rms_norm)
|
||||
@patch("torch.ops.vllm.maybe_wait_prefetch_done", side_effect=lambda x: None)
|
||||
@patch("torch.ops.vllm.maybe_chunk_residual",
|
||||
side_effect=mock_maybe_chunk_residual)
|
||||
def test_RMSNorm_forward(mock_maybe_chunk_residual, mock_add_rmsnorm,
|
||||
def test_RMSNorm_forward(mock_maybe_chunk_residual,
|
||||
mock_maybe_wait_prefetch_done, mock_add_rmsnorm,
|
||||
mock_rmsnorm, is_310p_return, residual, dummy_tensor):
|
||||
|
||||
with patch("vllm_ascend.utils.is_310p", return_value=is_310p_return):
|
||||
@@ -45,13 +47,17 @@ def test_RMSNorm_forward(mock_maybe_chunk_residual, mock_add_rmsnorm,
|
||||
expected_out_x = expected_arg_x + 1
|
||||
expected_out_residual = expected_arg_x.to(residual.dtype)
|
||||
|
||||
mock_maybe_chunk_residual.assert_called_once()
|
||||
mock_rmsnorm.assert_called_once()
|
||||
mock_maybe_wait_prefetch_done.assert_called_once()
|
||||
assert torch.allclose(out_x, expected_out_x)
|
||||
assert torch.allclose(out_residual, expected_out_residual)
|
||||
else:
|
||||
expected_out_x = 2 * dummy_tensor
|
||||
expected_out_residual = 2 * residual
|
||||
mock_maybe_chunk_residual.assert_called_once()
|
||||
mock_add_rmsnorm.assert_called_once()
|
||||
mock_maybe_wait_prefetch_done.assert_called_once()
|
||||
assert torch.allclose(out_x, expected_out_x)
|
||||
assert torch.allclose(out_residual, expected_out_residual)
|
||||
else:
|
||||
@@ -64,9 +70,11 @@ def test_RMSNorm_forward(mock_maybe_chunk_residual, mock_add_rmsnorm,
|
||||
|
||||
@patch("vllm_ascend.utils.is_310p", return_value=False)
|
||||
@patch("torch_npu.npu_add_rms_norm", side_effect=mock_add_rms_norm)
|
||||
@patch("torch.ops.vllm.maybe_wait_prefetch_done", side_effect=lambda x: None)
|
||||
@patch("torch.ops.vllm.maybe_chunk_residual",
|
||||
side_effect=mock_maybe_chunk_residual)
|
||||
def test_RMSNorm_forward_with_flashcomm_v1(mock_maybe_chunk_residual,
|
||||
mock_maybe_wait_prefetch_done,
|
||||
mock_add_rms_norm, mock_is310p):
|
||||
x = torch.randn(4, 512, dtype=torch.bfloat16)
|
||||
residual = torch.randn(16, 512, dtype=torch.bfloat16)
|
||||
@@ -79,6 +87,7 @@ def test_RMSNorm_forward_with_flashcomm_v1(mock_maybe_chunk_residual,
|
||||
|
||||
mock_maybe_chunk_residual.assert_called_once()
|
||||
mock_add_rms_norm.assert_called_once()
|
||||
mock_maybe_wait_prefetch_done.assert_called_once()
|
||||
assert out_residual.size(0) == 4
|
||||
assert torch.allclose(out_x, expected_out_x)
|
||||
assert torch.allclose(out_residual, expected_out_residual)
|
||||
|
||||
@@ -275,7 +275,12 @@ def test_torchair_deepseek_v2_mla_attention(mock_rms_norm, mock_distributed,
|
||||
|
||||
@patch("torch_npu.npu_add_rms_norm")
|
||||
@patch("torch_npu.npu_rms_norm")
|
||||
def test_torchair_deepseek_v2_decoder_layer(mock_rms_norm, mock_add_norm,
|
||||
@patch("torch.ops.vllm.maybe_wait_prefetch_done", side_effect=lambda x: None)
|
||||
@patch("torch.ops.vllm.maybe_chunk_residual",
|
||||
side_effect=lambda x, residual: residual)
|
||||
def test_torchair_deepseek_v2_decoder_layer(mock_maybe_chunk_residual,
|
||||
mock_maybe_wait_prefetch_done,
|
||||
mock_rms_norm, mock_add_norm,
|
||||
mock_distributed, base_config,
|
||||
vllm_config):
|
||||
mock_rms_norm.return_value = (torch.randn(2, 128), torch.randn(2, 128))
|
||||
|
||||
Reference in New Issue
Block a user