support FULL graph mode for GQA (#3970)
### What this PR does / why we need it?
The library currently supports only the FULL_DECODE_ONLY graph mode, which
enables full graph execution during the decode phase. This PR extends
support to full graph execution in both the prefill and decode phases,
referred to as FULL graph mode.
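
For reference, a minimal sketch of how a user opts into the new mode. The `LLM` entry point is standard vLLM, not part of this PR; the model name and capture sizes are copied from the e2e tests below, so treat the exact values as illustrative:

```python
# Sketch only: enabling FULL graph mode on Ascend (assumes vLLM v0.11 with
# the vllm-ascend plugin installed; values mirror the tests in this commit).
from vllm import LLM, SamplingParams

llm = LLM(
    model="Qwen/Qwen3-30B-A3B",
    tensor_parallel_size=2,
    enforce_eager=False,
    compilation_config={
        # "FULL" captures graphs for both prefill and decode;
        # "FULL_DECODE_ONLY" keeps the previous decode-only behavior.
        "cudagraph_mode": "FULL",
        # Batch sizes to capture graphs for.
        "cudagraph_capture_sizes": [4, 8, 24, 48, 60],
    },
)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(max_tokens=32, temperature=0.0))
```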
- vLLM version: v0.11.0
- vLLM main: 2918c1b49c
Signed-off-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com>
Co-authored-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com>
@@ -29,7 +29,7 @@ from tests.e2e.conftest import VllmRunner
 from tests.e2e.model_utils import check_outputs_equal
 
 
-def test_models_distributed_Qwen3_MOE_TP2_WITH_FULLGRAPH():
+def test_models_distributed_Qwen3_MOE_TP2_WITH_FULL_DECODE_ONLY():
     if 'HCCL_OP_EXPANSION_MODE' in os.environ:
         del os.environ['HCCL_OP_EXPANSION_MODE']
     prompts = [
@@ -42,15 +42,64 @@ def test_models_distributed_Qwen3_MOE_TP2_WITH_FULLGRAPH():
                     max_model_len=1024,
                     tensor_parallel_size=2,
                     enforce_eager=False,
-                    compilation_config={"cudagraph_mode":
-                                        "FULL_DECODE_ONLY"}) as runner:
+                    compilation_config={
+                        "cudagraph_mode": "FULL_DECODE_ONLY",
+                        "cudagraph_capture_sizes": [4, 8, 24, 48, 60]
+                    }) as runner:
         vllm_fullgraph_outputs = runner.model.generate(prompts,
                                                        sampling_params)
 
     with VllmRunner(
             model,
             max_model_len=1024,
-            enforce_eager=True,
             tensor_parallel_size=2,
+            enforce_eager=False,
     ) as runner:
         vllm_eager_outputs = runner.model.generate(prompts, sampling_params)
 
     vllm_fullgraph_outputs_list = []
     for output in vllm_fullgraph_outputs:
         vllm_fullgraph_outputs_list.append(
             (output.outputs[0].index, output.outputs[0].text))
 
     vllm_eager_outputs_list = []
     for output in vllm_eager_outputs:
         vllm_eager_outputs_list.append(
             (output.outputs[0].index, output.outputs[0].text))
 
     check_outputs_equal(
         outputs_0_lst=vllm_eager_outputs_list,
         outputs_1_lst=vllm_fullgraph_outputs_list,
         name_0="vllm_eager_outputs",
         name_1="vllm_fullgraph_outputs",
     )
 
 
+def test_models_distributed_Qwen3_MOE_TP2_WITH_FULL():
+    if 'HCCL_OP_EXPANSION_MODE' in os.environ:
+        del os.environ['HCCL_OP_EXPANSION_MODE']
+    prompts = [
+        "Hello, my name is", "The president of the United States is",
+        "The capital of France is", "The future of AI is"
+    ]
+    model = "Qwen/Qwen3-30B-A3B"
+    sampling_params = SamplingParams(max_tokens=32, temperature=0.0)
+    with VllmRunner(model,
+                    max_model_len=1024,
+                    tensor_parallel_size=2,
+                    enforce_eager=False,
+                    compilation_config={
+                        "cudagraph_mode": "FULL",
+                        "cudagraph_capture_sizes": [4, 8, 24, 48, 60]
+                    }) as runner:
+        vllm_fullgraph_outputs = runner.model.generate(prompts,
+                                                       sampling_params)
+
+    with VllmRunner(
+            model,
+            max_model_len=1024,
+            tensor_parallel_size=2,
+            enforce_eager=False,
+    ) as runner:
+        vllm_eager_outputs = runner.model.generate(prompts, sampling_params)
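The `cudagraph_capture_sizes` lists above bound which batch sizes get a captured graph. A hypothetical sketch of the bucketing rule, shown only to make the config concrete (the helper `pad_to_capture_size` is illustrative, not code from this PR or from vLLM):

```python
# Illustrative only: a batch typically runs on the smallest captured size
# that fits it; batches above the largest size fall back to uncaptured
# execution. The real selection logic lives inside vLLM, not here.
import bisect
from typing import Optional

def pad_to_capture_size(num_tokens: int,
                        capture_sizes: list[int]) -> Optional[int]:
    sizes = sorted(capture_sizes)
    i = bisect.bisect_left(sizes, num_tokens)
    return sizes[i] if i < len(sizes) else None  # None -> no captured graph

assert pad_to_capture_size(5, [4, 8, 24, 48, 60]) == 8
assert pad_to_capture_size(61, [4, 8, 24, 48, 60]) is None
```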
@@ -46,7 +46,7 @@ def mtp_correctness(sampling_config: SamplingParams,
 
     graph_mode_str = "PIECEWISE"
     if graph_mode == CUDAGraphMode.FULL:
-        graph_mode_str = "FULL"
+        graph_mode_str = "FULL_DECODE_ONLY"
 
     with VllmRunner(
             model_name,
@@ -63,7 +63,9 @@ def mtp_correctness(sampling_config: SamplingParams,
             enforce_eager=enforce_eager,
             max_model_len=2000,
             compilation_config=CompilationConfig(
-                cudagraph_mode=graph_mode_str),
+                cudagraph_mode=graph_mode_str,
+                cudagraph_capture_sizes=[12],
+            ),
             additional_config={"ascend_scheduler_config": {
                 "enabled": False
             }}) as spec_llm:
@@ -286,10 +286,12 @@ class TestAscendAttentionBackendImpl(TestBase):
 
         assert output.shape == (10, 8 * 64)
 
+    @patch('vllm_ascend.attention.attention_v1.get_forward_context')
     @patch('torch_npu._npu_reshape_and_cache')
     @patch('torch_npu._npu_flash_attention')
     def test_forward_prefill_no_cache(self, mock_flash_attention,
-                                      mock_reshape_cache):
+                                      mock_reshape_cache,
+                                      mock_get_forward_context):
         """Test forward pass in PrefillNoCache state"""
         query = torch.randn(10, 8 * 64)
         key = torch.randn(10, 8 * 64)
@@ -297,6 +299,8 @@ class TestAscendAttentionBackendImpl(TestBase):
         kv_cache = torch.empty(2, 5, 128, 8, 64)
         output = torch.empty_like(query)
 
+        mock_get_forward_context.return_value = MagicMock(capturing=False)
+
         metadata = self.attn_metadata
         metadata.attn_state = AscendAttentionState.PrefillNoCache
         metadata.attn_mask = torch.randn(1, 1, 10, 10)
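These hunks patch `get_forward_context` because the attention forward path now consults the context's `capturing` flag before choosing between graph-capture bookkeeping and an eager kernel launch. A new test for this backend would stub it the same way (pattern copied from the hunks above; requires vllm-ascend to be importable):

```python
from unittest.mock import MagicMock, patch

with patch('vllm_ascend.attention.attention_v1.get_forward_context') as ctx:
    # capturing=False takes the eager kernel path rather than the
    # graph-capture recording path.
    ctx.return_value = MagicMock(capturing=False)
    # ... call impl.forward(layer, query, key, value, kv_cache,
    #     metadata, output) as the tests above do ...
```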
@@ -316,7 +320,8 @@ class TestAscendAttentionBackendImpl(TestBase):
 
     @patch('torch_npu._npu_reshape_and_cache')
     @patch('torch_npu.npu_fused_infer_attention_score')
-    def test_forward_prefill_cache_hit(self,
+    @patch('vllm_ascend.attention.attention_v1.get_forward_context')
+    def test_forward_prefill_cache_hit(self, mock_get_forward_context,
                                        mock_npu_fused_infer_attention_score,
                                        mock_npu_reshape_and_cache):
         """Test forward pass in PrefillCacheHit state"""
@@ -326,8 +331,6 @@ class TestAscendAttentionBackendImpl(TestBase):
         kv_cache = torch.empty(2, 5, 128, 8, 64)
         output = torch.empty_like(query)
 
-        mock_npu_fused_infer_attention_score.return_value = (output, 1)
-
         metadata = self.attn_metadata
         metadata.attn_state = AscendAttentionState.PrefillCacheHit
         metadata.attn_mask = torch.randn(1, 1, 10, 10)
@@ -340,18 +343,23 @@ class TestAscendAttentionBackendImpl(TestBase):
         metadata.num_prefills = 10
         layer = self.layer_no_quant
 
+        mock_get_forward_context.return_value = MagicMock(capturing=False)
+        mock_npu_fused_infer_attention_score.return_value = (output,
+                                                             torch.ones(
+                                                                 10, 8, 64))
+
         output = self.impl.forward(layer, query, key, value, kv_cache,
                                    metadata, output)
 
         mock_npu_fused_infer_attention_score.assert_called_once()
         assert output.shape == (10, 8 * 64)
 
-    @patch('vllm_ascend.attention.attention_v1.get_forward_context')
-    @patch('torch_npu._npu_reshape_and_cache')
     @patch('torch_npu._npu_paged_attention')
-    def test_forward_decode_only(self, mock_paged_attention,
+    @patch('torch_npu._npu_reshape_and_cache')
+    @patch('vllm_ascend.attention.attention_v1.get_forward_context')
+    def test_forward_decode_only(self, mock_get_forward_context,
                                  mock_npu_reshape_and_cache,
-                                 mock_get_forward_context):
+                                 mock_paged_attention):
         """Test forward pass in DecodeOnly state"""
         query = torch.randn(10, 8 * 64)
         key = torch.randn(10, 8 * 64)
@@ -378,115 +386,11 @@ class TestAscendAttentionBackendImpl(TestBase):
         assert output.shape == (10, 8 * 64)
 
-    @patch('vllm_ascend.attention.attention_v1.get_forward_context')
-    @patch('vllm_ascend.attention.attention_v1.get_graph_params')
-    @patch('torch_npu._npu_reshape_and_cache')
-    @patch('torch_npu._npu_paged_attention')
-    @patch('torch.npu.graph_task_group_end')
-    @patch('torch.npu.graph_task_group_begin')
-    @patch('torch.npu.ExternalEvent')
-    @patch('torch_npu.npu.current_stream')
-    @patch('vllm_ascend.attention.attention_v1.weak_ref_tensors')
-    def test_paged_attention_with_existing_workspace(
-            self,
-            mock_get_forward_context,
-            mock_get_graph_params,
-            mock_npu_reshape_and_cache,
-            mock_paged_attention,
-            mock_graph_begin,
-            mock_graph_end,
-            mock_external_event_class,
-            mock_current_stream,
-            mock_weak_ref_tensors,
-    ):
-        graph_params = MagicMock()
-        attn_metadata = MagicMock()
-        num_tokens = 10
-
-        graph_params.workspaces = {num_tokens: 10}
-        graph_params.events = {num_tokens: []}
-        graph_params.attn_params = {num_tokens: []}
-        graph_params.handles = {num_tokens: []}
-
-        query = torch.randn(2, 5, 8)  # [batch_size, seq_len, hidden_size]
-        key_cache = MagicMock()
-        value_cache = MagicMock()
-        num_kv_heads = 4
-        num_heads = 8
-        scale = 0.1
-        output = torch.randn(2, 5, 8)
-
-        self_obj = MagicMock()
-        self_obj.key_cache = key_cache
-        self_obj.value_cache = value_cache
-        self_obj.num_kv_heads = num_kv_heads
-        self_obj.num_heads = num_heads
-        self_obj.scale = scale
-
-        mock_stream = MagicMock()
-        mock_current_stream.return_value = mock_stream
-        mock_event_instance = MagicMock()
-        mock_external_event_class.return_value = mock_event_instance
-
-        mock_handle = MagicMock()
-        mock_graph_end.return_value = mock_handle
-
-        workspace = graph_params.workspaces.get(num_tokens)
-        self.assertEqual(workspace, 10)
-
-        weak_ref_tensors = MagicMock(side_effect=lambda x: x)
-
-        # 2. Handle graph capturing mode
-        stream = mock_current_stream()
-        event = mock_external_event_class()
-        event.wait(stream)
-        event.reset(stream)
-        graph_params.events[num_tokens].append(event)
-        graph_params.attn_params[num_tokens].append((
-            weak_ref_tensors(query),
-            weak_ref_tensors(self_obj.key_cache),
-            weak_ref_tensors(self_obj.value_cache),
-            self_obj.num_kv_heads,
-            self_obj.num_heads,
-            self_obj.scale,
-            weak_ref_tensors(attn_metadata.block_tables),
-            attn_metadata.seq_lens,
-            output,
-        ))
-
-        mock_event_instance.wait.assert_called_once_with(mock_stream)
-        mock_event_instance.reset.assert_called_once_with(mock_stream)
-        self.assertEqual(len(graph_params.events[num_tokens]), 1)
-        self.assertEqual(len(graph_params.attn_params[num_tokens]), 1)
-
-        query = torch.randn(10, 8 * 64)
-        key = torch.randn(10, 8 * 64)
-        value = torch.randn(10, 8 * 64)
-        kv_cache = torch.empty(2, 5, 128, 8, 64)
-        output = torch.empty_like(query)
-
-        metadata = self.attn_metadata
-        metadata.attn_state = AscendAttentionState.DecodeOnly
-        metadata.seq_lens = torch.tensor([10])
-        metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
-        metadata.num_actual_tokens = 10
-        metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
-        metadata.num_decodes = 0
-        metadata.num_prefills = 10
-        layer = self.layer_no_quant
-
-        mock_get_forward_context.return_value = MagicMock(capturing=True)
-        mock_get_graph_params.return_value = graph_params
-
-        output = self.impl.forward(layer, query, key, value, kv_cache,
-                                   metadata, output)
-
-        mock_paged_attention.assert_called_once()
-        self.assertEqual(len(graph_params.handles[num_tokens]), 0)
-
-    @patch('torch_npu._npu_reshape_and_cache')
-    @patch('torch_npu.npu_fused_infer_attention_score')
-    def test_forward_decode_only_swa(self, mock_fused_infer_attention_score,
-                                     mock_npu_reshape_and_cache):
+    @patch('vllm_ascend.attention.attention_v1.get_forward_context')
+    @patch('torch_npu.npu_fused_infer_attention_score')
+    @patch('torch_npu._npu_reshape_and_cache')
+    def test_forward_decode_only_swa(self, mock_npu_reshape_and_cache,
+                                     mock_fused_infer_attention_score,
+                                     mock_get_forward_context):
         """Test forward pass in DecodeOnly state"""
         query = torch.randn(10, 8 * 64)
         key = torch.randn(10, 8 * 64)
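The test removed above exercised per-batch-size bookkeeping keyed by `num_tokens`: a cached workspace plus parallel lists of events, recorded kernel arguments, and task-group handles. A runnable pure-Python sketch of that container (field names are taken from the removed test; the class itself is an assumption, not the plugin's real `GraphParams`):

```python
# Sketch only: mirrors the graph_params fields the removed test populated.
from dataclasses import dataclass, field

@dataclass
class GraphParams:
    workspaces: dict = field(default_factory=dict)   # num_tokens -> workspace
    events: dict = field(default_factory=dict)       # num_tokens -> [event]
    attn_params: dict = field(default_factory=dict)  # num_tokens -> [args]
    handles: dict = field(default_factory=dict)      # num_tokens -> [handle]

params = GraphParams()
num_tokens = 10
params.workspaces[num_tokens] = object()  # reused across replays of this size
for table in (params.events, params.attn_params, params.handles):
    table[num_tokens] = []

# During capture, each batch size records its kernel arguments for replay:
params.attn_params[num_tokens].append(("query", "key_cache", "value_cache"))
assert len(params.attn_params[num_tokens]) == 1
```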
@@ -494,6 +398,8 @@ class TestAscendAttentionBackendImpl(TestBase):
         kv_cache = torch.empty(2, 5, 128, 8, 64)
         output = torch.empty(10, 8, 64)
 
+        mock_get_forward_context.return_value = MagicMock(capturing=False)
+
         metadata = self.attn_metadata
         metadata.attn_state = AscendAttentionState.DecodeOnly
         metadata.seq_lens = torch.tensor([10] * 10)
@@ -512,12 +418,12 @@ class TestAscendAttentionBackendImpl(TestBase):
         assert output.shape == (10, 8, 64)
 
     @patch('vllm_ascend.attention.attention_v1.get_forward_context')
-    @patch('torch_npu._npu_reshape_and_cache')
     @patch('torch_npu._npu_paged_attention')
     @patch('torch_npu.npu_fused_infer_attention_score')
+    @patch('torch_npu._npu_reshape_and_cache')
     def test_forward_decode_only_swa_seq_len_mismatch(
-            self, mock_fused_infer_attention_score, mock_paged_attention,
-            mock_npu_reshape_and_cache, mock_get_forward_context):
+            self, mock_npu_reshape_and_cache, mock_fused_infer_attention_score,
+            mock_paged_attention, mock_get_forward_context):
         """Test forward pass in DecodeOnly state when seq_len mismatches"""
         query = torch.randn(10, 8 * 64)
         key = torch.randn(10, 8 * 64)
@@ -535,11 +441,11 @@ class TestAscendAttentionBackendImpl(TestBase):
         metadata.num_decodes = 10
         metadata.num_prefills = 0
 
-        mock_fused_infer_attention_score.return_value = (torch.ones(10, 8,
-                                                                    64), 1)
-
         mock_get_forward_context.return_value = MagicMock(capturing=False)
 
+        mock_fused_infer_attention_score.return_value = (torch.ones(10, 8, 64),
+                                                         torch.ones(10, 8, 64))
+
         output = self.impl_swa.forward(layer, query, key, value, kv_cache,
                                        metadata, output)
@@ -548,11 +454,13 @@ class TestAscendAttentionBackendImpl(TestBase):
 
         assert output.shape == (10, 8 * 64)
 
+    @patch('vllm_ascend.attention.attention_v1.get_forward_context')
     @patch('vllm_ascend.attention.attention_v1.is_310p', return_value=False)
     @patch('torch_npu._npu_reshape_and_cache')
     @patch('vllm_ascend.attention.attention_v1.vanilla_chunked_prefill')
     def test_forward_head_size_192(self, mock_vanilla_prefill,
-                                   mock_npu_reshape_and_cache, mock_is_310p):
+                                   mock_npu_reshape_and_cache, mock_is_310p,
+                                   mock_get_forward_context):
         """Test forward pass when head_size is 192"""
 
         self.impl.head_size = 192
@@ -562,6 +470,8 @@ class TestAscendAttentionBackendImpl(TestBase):
         kv_cache = torch.empty(2, 5, 128, 8, 192)
         output = torch.empty_like(query)
 
+        mock_get_forward_context.return_value = MagicMock(capturing=False)
+
         metadata = self.attn_metadata
         metadata.attn_mask = torch.randn(1, 1, 10, 10)
         metadata.query_lens = torch.tensor([10])
@@ -580,11 +490,12 @@ class TestAscendAttentionBackendImpl(TestBase):
         mock_vanilla_prefill.assert_called_once()
         assert output.shape == (10, 8 * 192)
 
-    @patch('torch_npu._npu_reshape_and_cache')
+    @patch('vllm_ascend.attention.attention_v1.get_forward_context')
     @patch('torch_npu.npu_fused_infer_attention_score')
-    def test_forward_normal_v1_situation(self,
+    @patch('torch_npu._npu_reshape_and_cache')
+    def test_forward_normal_v1_situation(self, mock_npu_reshape_and_cache,
                                          mock_npu_fused_infer_attention_score,
-                                         mock_npu_reshape_and_cache):
+                                         mock_get_forward_context):
         """Test forward pass in normal V1 situation"""
         query = torch.randn(10, 8 * 64)
         key = torch.randn(10, 8 * 64)
@@ -592,8 +503,6 @@ class TestAscendAttentionBackendImpl(TestBase):
         kv_cache = torch.empty(2, 5, 128, 8, 64)
         output = torch.empty_like(query)
 
-        mock_npu_fused_infer_attention_score.return_value = (output, 1)
-
         metadata = self.attn_metadata
         metadata.attn_mask = torch.randn(1, 1, 10, 10)
         metadata.query_lens = torch.tensor([10])
@@ -604,6 +513,10 @@ class TestAscendAttentionBackendImpl(TestBase):
         metadata.num_decodes = 0
         metadata.num_prefills = 10
         layer = self.layer_no_quant
+        mock_get_forward_context.return_value = MagicMock(capturing=False)
+        mock_npu_fused_infer_attention_score.return_value = (output,
+                                                             torch.ones(
+                                                                 10, 8, 64))
 
         output = self.impl.forward(layer, query, key, value, kv_cache,
                                    metadata, output)
@@ -615,7 +528,8 @@ class TestAscendAttentionBackendImpl(TestBase):
     @patch('torch_npu._npu_reshape_and_cache')
     @patch('torch_npu.npu_fused_infer_attention_score')
     @patch('vllm_ascend.attention.attention_v1.is_310p', return_value=True)
-    def test_forward_310p_device(self, mock_is_310p,
+    @patch('vllm_ascend.attention.attention_v1.get_forward_context')
+    def test_forward_310p_device(self, mock_get_forward_context, mock_is_310p,
                                  mock_npu_fused_infer_attention_score,
                                  mock_npu_reshape_and_cache,
                                  mock_npu_format_cast):
@@ -626,8 +540,6 @@ class TestAscendAttentionBackendImpl(TestBase):
         kv_cache = torch.empty(2, 5, 128, 8, 64)
         output = torch.empty_like(query)
 
-        mock_npu_fused_infer_attention_score.return_value = (output, 1)
-
         metadata = self.attn_metadata
         metadata.attn_mask = torch.randn(1, 1, 10, 10)
         metadata.query_lens = torch.tensor([10])
@@ -641,6 +553,11 @@ class TestAscendAttentionBackendImpl(TestBase):
 
         mock_npu_format_cast.return_value = metadata.attn_mask
 
+        mock_get_forward_context.return_value = MagicMock(capturing=False)
+        mock_npu_fused_infer_attention_score.return_value = (output,
+                                                             torch.ones(
+                                                                 10, 8, 64))
+
         output = self.impl.forward(layer, query, key, value, kv_cache,
                                    metadata, output)