[Feat][Graph]Support FULL_DECEDE_ONLY mode for MLA models (#3125)
### What this PR does / why we need it?
Adds support for capturing the Multi-Layer Attention (MLA) decode
operation into an ACL graph. This improves performance by compiling the
attention kernel for single-token decoding.
Key changes include:
- Implementing the graph capture logic for the MLA kernel, including
workspace management and parameter updates.
- Modifying the rotary embedding (RoPE) handling to use pre-allocated
tensors, which is a requirement for graph capture.
- Adding a `build_for_graph_capture` method to the MLA metadata builder
to create dummy metadata during the graph compilation phase.
Known issues:
- Currently, MTP is not supported in FULL_DECEDE_ONLY mode -- we're
working on a fix
- We are preparing to remove update_mla_attn_params with
auto_dispatch_capture
### Does this PR introduce _any_ user-facing change?
compilation_config={
"cudagraph_mode": "FULL_DECODE_ONLY",
},
### How was this patch tested?
- vLLM version: v0.11.0
---------
Signed-off-by: panchao-hub <315134829@qq.com>
Signed-off-by: p00465316 <panchao13@huawei.com>
Co-authored-by: p00465316 <panchao13@huawei.com>
Co-authored-by: Yizhou Liu <liu_yizhou@outlook.com>
This commit is contained in:
@@ -20,6 +20,8 @@ Compare the outputs of vLLM with and without aclgraph.
|
||||
Run `pytest tests/compile/test_aclgraph.py`.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from vllm import SamplingParams
|
||||
|
||||
@@ -73,3 +75,76 @@ def test_models_with_aclgraph(
|
||||
name_0="vllm_eager_outputs",
|
||||
name_1="vllm_aclgraph_outputs",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("max_tokens", [5])
|
||||
def test_models_with_aclgraph_full_decode_only(
|
||||
model: str,
|
||||
max_tokens: int,
|
||||
) -> None:
|
||||
if 'HCCL_OP_EXPANSION_MODE' in os.environ:
|
||||
del os.environ['HCCL_OP_EXPANSION_MODE']
|
||||
prompts = [
|
||||
('Solve the following math problem step by step.'
|
||||
'The last line of your response should be of the form Answer: '
|
||||
'$Answer (without quotes) where $Answer is the answer to the problem.\n\n'
|
||||
'In triangle $ABC$, $\\sin \\angle A = \\frac{4}{5}$ and $\\angle A < 90^\\circ$. Let $D$'
|
||||
'be a point outside triangle $ABC$ such that $\\angle BAD = \\angle DAC$,'
|
||||
'$\\angle BDC = 90^\\circ$. Suppose $AD = 1$ and $\\frac{BD}{CD} = \\frac{3}{2}$.'
|
||||
'If $AB + AC$ can be expressed in the form $\\frac{a\\sqrt{b}}{c}$,'
|
||||
'where $a, b, c$ are pairwise relatively prime integers, find $a + b + c$.'
|
||||
),
|
||||
('Solve the following math problem step by step.'
|
||||
'The last line of your response should be of the form Answer: '
|
||||
'$Answer (without quotes) where $Answer is the answer to the problem.\n\n'
|
||||
'Let $ABCD$ be a unit square in the plane. Points $X$ and $Y$ are chosen'
|
||||
'independently and uniformly at random on the perimeter of $ABCD$.'
|
||||
'If the expected value of the area of triangle $\\triangle AXY$'
|
||||
'can be expressed as $\\frac{m}{n}$, for relatively prime positive'
|
||||
'integers $m$ and $n$, compute $m+n$.'),
|
||||
('Solve the following math problem step by step.'
|
||||
'The last line of your response should be of the form Answer: '
|
||||
'$Answer (without quotes) where $Answer is the answer to the problem.\n\n'
|
||||
'Let $a, b, c$ be distinct numbers such that the equations $x^2 + ax + 1 = 0$'
|
||||
'and $x^2 + bx + c = 0$ have a common real root, and the equations $x^2 + x + a = 0$'
|
||||
'and $x^2 + cx + b = 0$ also have a common real root.'
|
||||
'Compute the sum $a + b + c$.')
|
||||
]
|
||||
|
||||
sampling_params = SamplingParams(max_tokens=5,
|
||||
n=1,
|
||||
temperature=0.0,
|
||||
top_p=1.0,
|
||||
top_k=1)
|
||||
with VllmRunner(
|
||||
model,
|
||||
max_model_len=1024,
|
||||
enforce_eager=False,
|
||||
compilation_config={"cudagraph_mode": "FULL_DECODE_ONLY"},
|
||||
) as runner:
|
||||
vllm_aclgraph_outputs = runner.model.generate(prompts, sampling_params)
|
||||
|
||||
with VllmRunner(
|
||||
model,
|
||||
max_model_len=1024,
|
||||
enforce_eager=True,
|
||||
) as runner:
|
||||
vllm_eager_outputs = runner.model.generate(prompts, sampling_params)
|
||||
|
||||
vllm_aclgraph_outputs_list = []
|
||||
for output in vllm_aclgraph_outputs:
|
||||
vllm_aclgraph_outputs_list.append(
|
||||
(output.outputs[0].index, output.outputs[0].text))
|
||||
|
||||
vllm_eager_outputs_list = []
|
||||
for output in vllm_eager_outputs:
|
||||
vllm_eager_outputs_list.append(
|
||||
(output.outputs[0].index, output.outputs[0].text))
|
||||
|
||||
check_outputs_equal(
|
||||
outputs_0_lst=vllm_eager_outputs_list,
|
||||
outputs_1_lst=vllm_aclgraph_outputs_list,
|
||||
name_0="vllm_eager_outputs",
|
||||
name_1="vllm_aclgraph_outputs",
|
||||
)
|
||||
|
||||
@@ -461,11 +461,13 @@ class TestAscendMLAImpl(TestBase):
|
||||
self.assertEqual(out.shape, prefix_out.shape)
|
||||
self.assertEqual(lse.shape, prefix_lse.shape)
|
||||
|
||||
@patch('vllm_ascend.attention.mla_v1.get_forward_context')
|
||||
@patch("vllm_ascend.attention.mla_v1.AscendMLAImpl._v_up_proj")
|
||||
@patch("torch_npu.npu_fused_infer_attention_score")
|
||||
def test_forward_decode_without_graph(self,
|
||||
mock_npu_fused_infer_attention_score,
|
||||
mock_up_proj):
|
||||
mock_up_proj,
|
||||
mock_get_forward_context):
|
||||
num_tokens = 100
|
||||
block_size = 4
|
||||
q_nope = torch.randn(num_tokens, self.impl.num_heads,
|
||||
@@ -487,6 +489,7 @@ class TestAscendMLAImpl(TestBase):
|
||||
mock_up_proj.return_value = torch.randn(num_tokens,
|
||||
self.impl.num_heads,
|
||||
self.impl.v_head_dim)
|
||||
mock_get_forward_context.return_value = MagicMock(capturing=False)
|
||||
result = self.impl._forward_decode(q_nope, q_pe, k_nope, k_pe,
|
||||
block_size, metadata)
|
||||
self.assertEqual(result.shape[0], num_tokens)
|
||||
@@ -614,12 +617,13 @@ class TestAscendMLAImpl(TestBase):
|
||||
self.assertEqual(k_pe.shape[-1], self.impl.qk_rope_head_dim)
|
||||
self.assertEqual(k_nope.shape[-1], self.impl.kv_lora_rank)
|
||||
|
||||
@patch('vllm_ascend.attention.mla_v1.get_forward_context')
|
||||
@patch("torch.npu.stream")
|
||||
@patch("vllm_ascend.attention.mla_v1.get_multistream_comm_context")
|
||||
@patch("torch_npu.npu_fused_infer_attention_score")
|
||||
def test_forward_decode(self, mock_npu_fused_infer_attention_score,
|
||||
mock_get_multistream_comm_context,
|
||||
mock_npu_stream):
|
||||
mock_get_multistream_comm_context, mock_npu_stream,
|
||||
mock_get_forward_context):
|
||||
B = 2
|
||||
N = self.impl.num_kv_heads
|
||||
BS = 100
|
||||
@@ -644,6 +648,7 @@ class TestAscendMLAImpl(TestBase):
|
||||
]
|
||||
mock_get_multistream_comm_context.return_value = None
|
||||
|
||||
mock_get_forward_context.return_value = MagicMock(capturing=False)
|
||||
result = self.impl._forward_decode(q_nope, q_pe, k_nope, k_pe, BS,
|
||||
attn_metadata)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user