support FULL graph mode for GQA (#3970)

### What this PR does / why we need it?
The library currently supports only the FULL_DECODE_ONLY graph mode, which
enables full graph execution during the decode phase. This PR extends
support to full graph execution during both the prefill and decode phases,
referred to as FULL graph mode.
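
As a usage sketch (mirroring the e2e test added in this PR under
`tests/e2e/multicard/test_full_graph_mode.py`; the model and capture sizes are
just the test's values, not requirements), FULL graph mode is selected through
`compilation_config`:

```python
# Usage sketch: enable FULL graph mode (prefill and decode both captured)
# on Ascend, mirroring the values used in the new e2e test.
from vllm import LLM, SamplingParams

llm = LLM(
    model="Qwen/Qwen3-30B-A3B",        # model used by the e2e test
    tensor_parallel_size=2,
    max_model_len=1024,
    enforce_eager=False,
    compilation_config={
        "cudagraph_mode": "FULL",      # previously only "FULL_DECODE_ONLY"
        "cudagraph_capture_sizes": [4, 8, 24, 48, 60],
    },
)
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(max_tokens=32, temperature=0.0))
```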

- vLLM version: v0.11.0
- vLLM main: 2918c1b49c

Signed-off-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com>
Co-authored-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com>
Author: XiaoxinWang
Date: 2025-11-17 10:50:35 +08:00 (committed by GitHub)
parent c334114f69
commit e38ef2c434
11 changed files with 328 additions and 296 deletions

View File

@@ -180,6 +180,7 @@ jobs:
if: ${{ inputs.type == 'full' }}
run: |
pytest -sv tests/e2e/multicard/test_torchair_graph_mode.py
pytest -sv tests/e2e/multicard/test_full_graph_mode.py
pytest -sv tests/e2e/multicard/test_data_parallel.py
pytest -sv tests/e2e/multicard/test_expert_parallel.py
pytest -sv tests/e2e/multicard/test_external_launcher.py

View File

@@ -29,7 +29,7 @@ from tests.e2e.conftest import VllmRunner
from tests.e2e.model_utils import check_outputs_equal
def test_models_distributed_Qwen3_MOE_TP2_WITH_FULLGRAPH():
def test_models_distributed_Qwen3_MOE_TP2_WITH_FULL_DECODE_ONLY():
if 'HCCL_OP_EXPANSION_MODE' in os.environ:
del os.environ['HCCL_OP_EXPANSION_MODE']
prompts = [
@@ -42,15 +42,64 @@ def test_models_distributed_Qwen3_MOE_TP2_WITH_FULLGRAPH():
max_model_len=1024,
tensor_parallel_size=2,
enforce_eager=False,
compilation_config={"cudagraph_mode":
"FULL_DECODE_ONLY"}) as runner:
compilation_config={
"cudagraph_mode": "FULL_DECODE_ONLY",
"cudagraph_capture_sizes": [4, 8, 24, 48, 60]
}) as runner:
vllm_fullgraph_outputs = runner.model.generate(prompts,
sampling_params)
with VllmRunner(
model,
max_model_len=1024,
enforce_eager=True,
tensor_parallel_size=2,
enforce_eager=False,
) as runner:
vllm_eager_outputs = runner.model.generate(prompts, sampling_params)
vllm_fullgraph_outputs_list = []
for output in vllm_fullgraph_outputs:
vllm_fullgraph_outputs_list.append(
(output.outputs[0].index, output.outputs[0].text))
vllm_eager_outputs_list = []
for output in vllm_eager_outputs:
vllm_eager_outputs_list.append(
(output.outputs[0].index, output.outputs[0].text))
check_outputs_equal(
outputs_0_lst=vllm_eager_outputs_list,
outputs_1_lst=vllm_fullgraph_outputs_list,
name_0="vllm_eager_outputs",
name_1="vllm_fullgraph_outputs",
)
def test_models_distributed_Qwen3_MOE_TP2_WITH_FULL():
if 'HCCL_OP_EXPANSION_MODE' in os.environ:
del os.environ['HCCL_OP_EXPANSION_MODE']
prompts = [
"Hello, my name is", "The president of the United States is",
"The capital of France is", "The future of AI is"
]
model = "Qwen/Qwen3-30B-A3B"
sampling_params = SamplingParams(max_tokens=32, temperature=0.0)
with VllmRunner(model,
max_model_len=1024,
tensor_parallel_size=2,
enforce_eager=False,
compilation_config={
"cudagraph_mode": "FULL",
"cudagraph_capture_sizes": [4, 8, 24, 48, 60]
}) as runner:
vllm_fullgraph_outputs = runner.model.generate(prompts,
sampling_params)
with VllmRunner(
model,
max_model_len=1024,
tensor_parallel_size=2,
enforce_eager=False,
) as runner:
vllm_eager_outputs = runner.model.generate(prompts, sampling_params)

View File

@@ -46,7 +46,7 @@ def mtp_correctness(sampling_config: SamplingParams,
graph_mode_str = "PIECEWISE"
if graph_mode == CUDAGraphMode.FULL:
graph_mode_str = "FULL"
graph_mode_str = "FULL_DECODE_ONLY"
with VllmRunner(
model_name,
@@ -63,7 +63,9 @@ def mtp_correctness(sampling_config: SamplingParams,
enforce_eager=enforce_eager,
max_model_len=2000,
compilation_config=CompilationConfig(
cudagraph_mode=graph_mode_str),
cudagraph_mode=graph_mode_str,
cudagraph_capture_sizes=[12],
),
additional_config={"ascend_scheduler_config": {
"enabled": False
}}) as spec_llm:

View File

@@ -286,10 +286,12 @@ class TestAscendAttentionBackendImpl(TestBase):
assert output.shape == (10, 8 * 64)
@patch('vllm_ascend.attention.attention_v1.get_forward_context')
@patch('torch_npu._npu_reshape_and_cache')
@patch('torch_npu._npu_flash_attention')
def test_forward_prefill_no_cache(self, mock_flash_attention,
mock_reshape_cache):
mock_reshape_cache,
mock_get_forward_context):
"""Test forward pass in PrefillNoCache state"""
query = torch.randn(10, 8 * 64)
key = torch.randn(10, 8 * 64)
@@ -297,6 +299,8 @@ class TestAscendAttentionBackendImpl(TestBase):
kv_cache = torch.empty(2, 5, 128, 8, 64)
output = torch.empty_like(query)
mock_get_forward_context.return_value = MagicMock(capturing=False)
metadata = self.attn_metadata
metadata.attn_state = AscendAttentionState.PrefillNoCache
metadata.attn_mask = torch.randn(1, 1, 10, 10)
@@ -316,7 +320,8 @@ class TestAscendAttentionBackendImpl(TestBase):
@patch('torch_npu._npu_reshape_and_cache')
@patch('torch_npu.npu_fused_infer_attention_score')
def test_forward_prefill_cache_hit(self,
@patch('vllm_ascend.attention.attention_v1.get_forward_context')
def test_forward_prefill_cache_hit(self, mock_get_forward_context,
mock_npu_fused_infer_attention_score,
mock_npu_reshape_and_cache):
"""Test forward pass in PrefillCacheHit state"""
@@ -326,8 +331,6 @@ class TestAscendAttentionBackendImpl(TestBase):
kv_cache = torch.empty(2, 5, 128, 8, 64)
output = torch.empty_like(query)
mock_npu_fused_infer_attention_score.return_value = (output, 1)
metadata = self.attn_metadata
metadata.attn_state = AscendAttentionState.PrefillCacheHit
metadata.attn_mask = torch.randn(1, 1, 10, 10)
@@ -340,18 +343,23 @@ class TestAscendAttentionBackendImpl(TestBase):
metadata.num_prefills = 10
layer = self.layer_no_quant
mock_get_forward_context.return_value = MagicMock(capturing=False)
mock_npu_fused_infer_attention_score.return_value = (output,
torch.ones(
10, 8, 64))
output = self.impl.forward(layer, query, key, value, kv_cache,
metadata, output)
mock_npu_fused_infer_attention_score.assert_called_once()
assert output.shape == (10, 8 * 64)
@patch('vllm_ascend.attention.attention_v1.get_forward_context')
@patch('torch_npu._npu_reshape_and_cache')
@patch('torch_npu._npu_paged_attention')
def test_forward_decode_only(self, mock_paged_attention,
@patch('torch_npu._npu_reshape_and_cache')
@patch('vllm_ascend.attention.attention_v1.get_forward_context')
def test_forward_decode_only(self, mock_get_forward_context,
mock_npu_reshape_and_cache,
mock_get_forward_context):
mock_paged_attention):
"""Test forward pass in DecodeOnly state"""
query = torch.randn(10, 8 * 64)
key = torch.randn(10, 8 * 64)
@@ -378,115 +386,11 @@ class TestAscendAttentionBackendImpl(TestBase):
assert output.shape == (10, 8 * 64)
@patch('vllm_ascend.attention.attention_v1.get_forward_context')
@patch('vllm_ascend.attention.attention_v1.get_graph_params')
@patch('torch_npu._npu_reshape_and_cache')
@patch('torch_npu._npu_paged_attention')
@patch('torch.npu.graph_task_group_end')
@patch('torch.npu.graph_task_group_begin')
@patch('torch.npu.ExternalEvent')
@patch('torch_npu.npu.current_stream')
@patch('vllm_ascend.attention.attention_v1.weak_ref_tensors')
def test_paged_attention_with_existing_workspace(
self,
mock_get_forward_context,
mock_get_graph_params,
mock_npu_reshape_and_cache,
mock_paged_attention,
mock_graph_begin,
mock_graph_end,
mock_external_event_class,
mock_current_stream,
mock_weak_ref_tensors,
):
graph_params = MagicMock()
attn_metadata = MagicMock()
num_tokens = 10
graph_params.workspaces = {num_tokens: 10}
graph_params.events = {num_tokens: []}
graph_params.attn_params = {num_tokens: []}
graph_params.handles = {num_tokens: []}
query = torch.randn(2, 5, 8) # [batch_size, seq_len, hidden_size]
key_cache = MagicMock()
value_cache = MagicMock()
num_kv_heads = 4
num_heads = 8
scale = 0.1
output = torch.randn(2, 5, 8)
self_obj = MagicMock()
self_obj.key_cache = key_cache
self_obj.value_cache = value_cache
self_obj.num_kv_heads = num_kv_heads
self_obj.num_heads = num_heads
self_obj.scale = scale
mock_stream = MagicMock()
mock_current_stream.return_value = mock_stream
mock_event_instance = MagicMock()
mock_external_event_class.return_value = mock_event_instance
mock_handle = MagicMock()
mock_graph_end.return_value = mock_handle
workspace = graph_params.workspaces.get(num_tokens)
self.assertEqual(workspace, 10)
weak_ref_tensors = MagicMock(side_effect=lambda x: x)
# 2. Handle graph capturing mode
stream = mock_current_stream()
event = mock_external_event_class()
event.wait(stream)
event.reset(stream)
graph_params.events[num_tokens].append(event)
graph_params.attn_params[num_tokens].append((
weak_ref_tensors(query),
weak_ref_tensors(self_obj.key_cache),
weak_ref_tensors(self_obj.value_cache),
self_obj.num_kv_heads,
self_obj.num_heads,
self_obj.scale,
weak_ref_tensors(attn_metadata.block_tables),
attn_metadata.seq_lens,
output,
))
mock_event_instance.wait.assert_called_once_with(mock_stream)
mock_event_instance.reset.assert_called_once_with(mock_stream)
self.assertEqual(len(graph_params.events[num_tokens]), 1)
self.assertEqual(len(graph_params.attn_params[num_tokens]), 1)
query = torch.randn(10, 8 * 64)
key = torch.randn(10, 8 * 64)
value = torch.randn(10, 8 * 64)
kv_cache = torch.empty(2, 5, 128, 8, 64)
output = torch.empty_like(query)
metadata = self.attn_metadata
metadata.attn_state = AscendAttentionState.DecodeOnly
metadata.seq_lens = torch.tensor([10])
metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
metadata.num_actual_tokens = 10
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
metadata.num_decodes = 0
metadata.num_prefills = 10
layer = self.layer_no_quant
mock_get_forward_context.return_value = MagicMock(capturing=True)
mock_get_graph_params.return_value = graph_params
output = self.impl.forward(layer, query, key, value, kv_cache,
metadata, output)
mock_paged_attention.assert_called_once()
self.assertEqual(len(graph_params.handles[num_tokens]), 0)
@patch('torch_npu._npu_reshape_and_cache')
@patch('torch_npu.npu_fused_infer_attention_score')
def test_forward_decode_only_swa(self, mock_fused_infer_attention_score,
mock_npu_reshape_and_cache):
@patch('torch_npu._npu_reshape_and_cache')
def test_forward_decode_only_swa(self, mock_npu_reshape_and_cache,
mock_fused_infer_attention_score,
mock_get_forward_context):
"""Test forward pass in DecodeOnly state"""
query = torch.randn(10, 8 * 64)
key = torch.randn(10, 8 * 64)
@@ -494,6 +398,8 @@ class TestAscendAttentionBackendImpl(TestBase):
kv_cache = torch.empty(2, 5, 128, 8, 64)
output = torch.empty(10, 8, 64)
mock_get_forward_context.return_value = MagicMock(capturing=False)
metadata = self.attn_metadata
metadata.attn_state = AscendAttentionState.DecodeOnly
metadata.seq_lens = torch.tensor([10] * 10)
@@ -512,12 +418,12 @@ class TestAscendAttentionBackendImpl(TestBase):
assert output.shape == (10, 8, 64)
@patch('vllm_ascend.attention.attention_v1.get_forward_context')
@patch('torch_npu._npu_reshape_and_cache')
@patch('torch_npu._npu_paged_attention')
@patch('torch_npu.npu_fused_infer_attention_score')
@patch('torch_npu._npu_reshape_and_cache')
def test_forward_decode_only_swa_seq_len_mismatch(
self, mock_fused_infer_attention_score, mock_paged_attention,
mock_npu_reshape_and_cache, mock_get_forward_context):
self, mock_npu_reshape_and_cache, mock_fused_infer_attention_score,
mock_paged_attention, mock_get_forward_context):
"""Test forward pass in DecodeOnly state when seq)len_mismatch"""
query = torch.randn(10, 8 * 64)
key = torch.randn(10, 8 * 64)
@@ -535,11 +441,11 @@ class TestAscendAttentionBackendImpl(TestBase):
metadata.num_decodes = 10
metadata.num_prefills = 0
mock_fused_infer_attention_score.return_value = (torch.ones(10, 8,
64), 1)
mock_get_forward_context.return_value = MagicMock(capturing=False)
mock_fused_infer_attention_score.return_value = (torch.ones(10, 8, 64),
torch.ones(10, 8, 64))
output = self.impl_swa.forward(layer, query, key, value, kv_cache,
metadata, output)
@@ -548,11 +454,13 @@ class TestAscendAttentionBackendImpl(TestBase):
assert output.shape == (10, 8 * 64)
@patch('vllm_ascend.attention.attention_v1.get_forward_context')
@patch('vllm_ascend.attention.attention_v1.is_310p', return_value=False)
@patch('torch_npu._npu_reshape_and_cache')
@patch('vllm_ascend.attention.attention_v1.vanilla_chunked_prefill')
def test_forward_head_size_192(self, mock_vanilla_prefill,
mock_npu_reshape_and_cache, mock_is_310p):
mock_npu_reshape_and_cache, mock_is_310p,
mock_get_forward_context):
"""Test forward pass when head_size is 192"""
self.impl.head_size = 192
@@ -562,6 +470,8 @@ class TestAscendAttentionBackendImpl(TestBase):
kv_cache = torch.empty(2, 5, 128, 8, 192)
output = torch.empty_like(query)
mock_get_forward_context.return_value = MagicMock(capturing=False)
metadata = self.attn_metadata
metadata.attn_mask = torch.randn(1, 1, 10, 10)
metadata.query_lens = torch.tensor([10])
@@ -580,11 +490,12 @@ class TestAscendAttentionBackendImpl(TestBase):
mock_vanilla_prefill.assert_called_once()
assert output.shape == (10, 8 * 192)
@patch('torch_npu._npu_reshape_and_cache')
@patch('vllm_ascend.attention.attention_v1.get_forward_context')
@patch('torch_npu.npu_fused_infer_attention_score')
def test_forward_normal_v1_situation(self,
@patch('torch_npu._npu_reshape_and_cache')
def test_forward_normal_v1_situation(self, mock_npu_reshape_and_cache,
mock_npu_fused_infer_attention_score,
mock_npu_reshape_and_cache):
mock_get_forward_context):
"""Test forward pass in normal V1 situation"""
query = torch.randn(10, 8 * 64)
key = torch.randn(10, 8 * 64)
@@ -592,8 +503,6 @@ class TestAscendAttentionBackendImpl(TestBase):
kv_cache = torch.empty(2, 5, 128, 8, 64)
output = torch.empty_like(query)
mock_npu_fused_infer_attention_score.return_value = (output, 1)
metadata = self.attn_metadata
metadata.attn_mask = torch.randn(1, 1, 10, 10)
metadata.query_lens = torch.tensor([10])
@@ -604,6 +513,10 @@ class TestAscendAttentionBackendImpl(TestBase):
metadata.num_decodes = 0
metadata.num_prefills = 10
layer = self.layer_no_quant
mock_get_forward_context.return_value = MagicMock(capturing=False)
mock_npu_fused_infer_attention_score.return_value = (output,
torch.ones(
10, 8, 64))
output = self.impl.forward(layer, query, key, value, kv_cache,
metadata, output)
@@ -615,7 +528,8 @@ class TestAscendAttentionBackendImpl(TestBase):
@patch('torch_npu._npu_reshape_and_cache')
@patch('torch_npu.npu_fused_infer_attention_score')
@patch('vllm_ascend.attention.attention_v1.is_310p', return_value=True)
def test_forward_310p_device(self, mock_is_310p,
@patch('vllm_ascend.attention.attention_v1.get_forward_context')
def test_forward_310p_device(self, mock_get_forward_context, mock_is_310p,
mock_npu_fused_infer_attention_score,
mock_npu_reshape_and_cache,
mock_npu_format_cast):
@@ -626,8 +540,6 @@ class TestAscendAttentionBackendImpl(TestBase):
kv_cache = torch.empty(2, 5, 128, 8, 64)
output = torch.empty_like(query)
mock_npu_fused_infer_attention_score.return_value = (output, 1)
metadata = self.attn_metadata
metadata.attn_mask = torch.randn(1, 1, 10, 10)
metadata.query_lens = torch.tensor([10])
@@ -641,6 +553,11 @@ class TestAscendAttentionBackendImpl(TestBase):
mock_npu_format_cast.return_value = metadata.attn_mask
mock_get_forward_context.return_value = MagicMock(capturing=False)
mock_npu_fused_infer_attention_score.return_value = (output,
torch.ones(
10, 8, 64))
output = self.impl.forward(layer, query, key, value, kv_cache,
metadata, output)

View File

@@ -195,6 +195,7 @@ class AscendMetadataForDecode:
class AscendMetadata:
# **************************** Basic Properties ************************** #
attn_mask: Optional[torch.Tensor] = None
fia_attn_mask: Optional[torch.Tensor] = None
# Current state of this attention run.
attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
@@ -215,6 +216,7 @@ class AscendMetadata:
seq_lens: torch.Tensor = None
seq_lens_list: List[int] = None # type: ignore
actual_seq_lengths_q: List[int] = None # type: ignore
query_start_loc_list: List[int] = None # type: ignore
query_start_loc: torch.Tensor = None
query_lens: torch.Tensor = None
@@ -241,7 +243,8 @@ class AscendMetadata:
class AscendAttentionMetadataBuilder:
# Does this backend/builder support ACL Graphs for attention (default: no).
aclgraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
AttentionCGSupport.ALWAYS
# AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
# Does this backend/builder reorder the batch?
# If not, set this to None. Otherwise set it to the query
# length that will be pulled into the front of the batch.
@@ -321,6 +324,7 @@ class AscendAttentionMetadataBuilder:
num_actual_tokens_pcp_padded]
# slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens]
attn_mask = common_attn_metadata.attn_mask
fia_attn_mask = common_attn_metadata.fia_attn_mask
attn_state = common_attn_metadata.attn_state
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:
num_reqs
@@ -471,6 +475,7 @@ class AscendAttentionMetadataBuilder:
num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded,
block_tables=block_table,
query_start_loc=query_start_loc,
query_start_loc_list=query_start_loc_cpu[1:].tolist(),
query_lens=query_lens,
seq_lens=seq_lens,
seq_lens_list=seq_lens.tolist(),
@@ -478,6 +483,7 @@ class AscendAttentionMetadataBuilder:
actual_seq_lengths_q=query_start_loc_cpu[1:].tolist(),
slot_mapping=slot_mapping,
attn_mask=attn_mask,
fia_attn_mask=fia_attn_mask,
attn_state=attn_state,
num_prefills=num_prefills,
num_decodes=num_decodes,
@@ -565,6 +571,113 @@ class AscendAttentionBackendImpl(AttentionImpl):
self.dcp_group = get_dcp_group(
).device_group if self.dcp_size > 1 else None
def full_graph_attention(self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_metadata: AscendMetadata,
output: torch.Tensor,
num_tokens=0):
if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
block_size = 128
block_table = None
actual_seq_lengths_kv = attn_metadata.query_start_loc_list
elif attn_metadata.attn_state == \
AscendAttentionState.PrefillCacheHit:
batch_size = attn_metadata.query_lens.shape[0]
block_table = attn_metadata.block_tables[:batch_size, :]
num_block, block_size, _, _ = self.key_cache.shape # type: ignore
key = self.key_cache.view( # type: ignore
num_block, block_size, -1)
value = self.value_cache.view( # type: ignore
num_block, block_size, -1)
actual_seq_lengths_kv = attn_metadata.seq_lens_list
elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
num_block, block_size, _, _ = self.key_cache.shape # type: ignore
key = self.key_cache.view( # type: ignore
num_block, block_size, -1)
value = self.value_cache.view( # type: ignore
num_block, block_size, -1)
block_table = attn_metadata.block_tables
actual_seq_lengths_kv = attn_metadata.seq_lens_list
# Normal V1 situation.
else:
num_block, block_size, _, _ = self.key_cache.shape # type: ignore
key = self.key_cache.view( # type: ignore
num_block, block_size, -1)
value = self.value_cache.view( # type: ignore
num_block, block_size, -1)
block_table = attn_metadata.block_tables
actual_seq_lengths_kv = attn_metadata.seq_lens_list
num_tokens = attn_metadata.query_start_loc_list[-1]
query = query[:num_tokens]
graph_params = get_graph_params()
query_start_loc = attn_metadata.query_start_loc_list
# Prepare tensors for attention output
# TODO: Refactor this to step-level instead of layer-level
# Get workspace from cache or calculate it if not present.
workspace = graph_params.workspaces.get(num_tokens)
softmax_lse = torch.empty(1, dtype=query.dtype, device=query.device)
if workspace is None:
workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
query=query,
key=key,
value=value,
atten_mask=attn_metadata.fia_attn_mask,
block_table=block_table,
input_layout="TND",
block_size=block_size,
actual_seq_lengths=query_start_loc,
actual_seq_lengths_kv=actual_seq_lengths_kv,
num_key_value_heads=self.num_kv_heads,
num_heads=self.num_heads,
sparse_mode=3,
scale=self.scale,
)
update_graph_params_workspaces(num_tokens, workspace)
# Handle graph capturing mode
stream = torch_npu.npu.current_stream()
event = torch.npu.ExternalEvent()
event.wait(stream)
event.reset(stream)
graph_params.events[num_tokens].append(event)
graph_params.attn_params[num_tokens].append(
(weak_ref_tensors(query), weak_ref_tensors(key),
weak_ref_tensors(value), weak_ref_tensors(block_table),
weak_ref_tensors(attn_metadata.fia_attn_mask), block_size,
actual_seq_lengths_kv, query_start_loc, self.num_kv_heads,
self.num_heads, self.scale, weak_ref_tensors(output),
weak_ref_tensors(softmax_lse)))
torch.npu.graph_task_group_begin(stream)
torch_npu.npu_fused_infer_attention_score.out(
query=query,
key=key,
value=value,
atten_mask=attn_metadata.fia_attn_mask,
block_table=block_table,
input_layout="TND",
block_size=block_size,
actual_seq_lengths=query_start_loc,
actual_seq_lengths_kv=actual_seq_lengths_kv,
num_key_value_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale=self.scale,
sparse_mode=3,
workspace=workspace,
out=[output, softmax_lse],
)
output = output.view(num_tokens, self.num_heads, self.head_size)
handle = torch.npu.graph_task_group_end(stream)
graph_params.handles[num_tokens].append(handle)
return output, num_tokens
def _forward_prefill_no_cache(
self,
query: torch.Tensor,
@@ -692,70 +805,16 @@ class AscendAttentionBackendImpl(AttentionImpl):
output = output.view(batch_size, self.num_heads, self.head_size)
else:
graph_params = get_graph_params()
forward_context: ForwardContext = get_forward_context()
num_tokens = query.shape[0]
if forward_context.capturing:
# Get workspace from cache or calculate it if not present.
workspace = graph_params.workspaces.get(num_tokens)
if workspace is None:
workspace = torch_npu._npu_paged_attention_get_workspace(
query=query,
key_cache=self.key_cache,
value_cache=self.value_cache,
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale_value=self.scale,
block_table=attn_metadata.block_tables,
context_lens=attn_metadata.seq_lens,
out=output)
update_graph_params_workspaces(num_tokens,
weak_ref_tensors(workspace))
# Handle graph capturing mode
stream = torch_npu.npu.current_stream()
event = torch.npu.ExternalEvent()
event.wait(stream)
event.reset(stream)
graph_params.events[num_tokens].append(event)
graph_params.attn_params[num_tokens].append((
weak_ref_tensors(query),
weak_ref_tensors(self.key_cache),
weak_ref_tensors(self.value_cache),
self.num_kv_heads,
self.num_heads,
self.scale,
attn_metadata.block_tables,
attn_metadata.seq_lens,
weak_ref_tensors(output),
))
torch.npu.graph_task_group_begin(stream)
torch_npu._npu_paged_attention(
query=query,
key_cache=self.key_cache,
value_cache=self.value_cache,
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale_value=self.scale,
block_table=attn_metadata.block_tables,
context_lens=attn_metadata.seq_lens,
out=output,
workspace=workspace)
handle = torch.npu.graph_task_group_end(stream)
graph_params.handles[num_tokens].append(handle)
else:
torch_npu._npu_paged_attention(
query=query,
key_cache=self.key_cache,
value_cache=self.value_cache,
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale_value=self.scale,
block_table=attn_metadata.block_tables,
context_lens=attn_metadata.seq_lens,
out=output)
torch_npu._npu_paged_attention(
query=query,
key_cache=self.key_cache,
value_cache=self.value_cache,
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale_value=self.scale,
block_table=attn_metadata.block_tables,
context_lens=attn_metadata.seq_lens,
out=output)
return output
def _forward_v1_style(
@@ -819,7 +878,6 @@ class AscendAttentionBackendImpl(AttentionImpl):
scale=self.scale,
sparse_mode=3,
)
return output
def _attention_with_nomask_and_mask(self, q: torch.Tensor,
@@ -1481,47 +1539,51 @@ class AscendAttentionBackendImpl(AttentionImpl):
num_decode_tokens:attn_metadata.
num_actual_tokens_pcp_padded])
if self.pcp_size * self.dcp_size > 1:
intermediate_output = self._forward_pcp_dcp(
query, key, value, kv_cache, attn_metadata, output)
elif attn_type == AttentionType.ENCODER_ONLY:
# TODO(zzzwwjj): Deal with this `cum_seq_len` more elegantly.
cum_seq_len = attn_metadata.query_start_loc[1:].tolist()
intermediate_output = torch_npu.npu_fusion_attention(
query,
key,
value,
head_num=self.num_heads,
input_layout="TND",
scale=self.scale,
sparse_mode=4,
atten_mask=attn_metadata.attn_mask,
pre_tockens=attn_metadata.max_query_len,
next_tockens=attn_metadata.max_query_len,
actual_seq_qlen=cum_seq_len,
actual_seq_kvlen=cum_seq_len,
)[0]
# V0-Style scheduler situation.
elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
intermediate_output = self._forward_prefill_no_cache(
query, key, value, attn_metadata, output, num_tokens)
elif attn_metadata.attn_state == \
AscendAttentionState.PrefillCacheHit:
intermediate_output = self._forward_prefill_cache_hit(
query, attn_metadata, output)
elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
intermediate_output = self._forward_decode_only(
query, attn_metadata, output)
# Normal V1 situation.
forward_context: ForwardContext = get_forward_context()
if not forward_context.capturing:
if self.pcp_size * self.dcp_size > 1:
intermediate_output = self._forward_pcp_dcp(
query, key, value, kv_cache, attn_metadata, output)
elif attn_type == AttentionType.ENCODER_ONLY:
# TODO(zzzwwjj): Deal with this `cum_seq_len` more elegantly.
cum_seq_len = attn_metadata.query_start_loc[1:].tolist()
intermediate_output = torch_npu.npu_fusion_attention(
query,
key,
value,
head_num=self.num_heads,
input_layout="TND",
scale=self.scale,
sparse_mode=4,
atten_mask=attn_metadata.attn_mask,
pre_tockens=attn_metadata.max_query_len,
next_tockens=attn_metadata.max_query_len,
actual_seq_qlen=cum_seq_len,
actual_seq_kvlen=cum_seq_len,
)[0]
# V0-Style scheduler situation.
elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
intermediate_output = self._forward_prefill_no_cache(
query, key, value, attn_metadata, output, num_tokens)
elif attn_metadata.attn_state == \
AscendAttentionState.PrefillCacheHit:
intermediate_output = self._forward_prefill_cache_hit(
query, attn_metadata, output)
elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
intermediate_output = self._forward_decode_only(
query, attn_metadata, output)
# Normal V1 situation.
else:
# npu_fused_infer_attention_score does not support cases
# where query.shape[0] != attn_metadata.query_start_loc[-1].
# Thus we need unpad it here.
num_tokens = attn_metadata.query_start_loc[-1]
query = query[:num_tokens]
intermediate_output = self._forward_v1_style(
query, attn_metadata, output)
else:
# npu_fused_infer_attention_score does not support cases
# where query.shape[0] != attn_metadata.query_start_loc[-1].
# Thus we need unpad it here.
num_tokens = attn_metadata.query_start_loc[-1]
query = query[:num_tokens]
intermediate_output = self._forward_v1_style(
query, attn_metadata, output)
intermediate_output, num_tokens = self.full_graph_attention(
query, key, value, attn_metadata, output)
output[:num_tokens] = intermediate_output[:num_tokens]
return output
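
For orientation: when `forward_context.capturing` is false, the pre-existing
eager dispatch (pcp/dcp, encoder-only, prefill, decode, V1-style) runs
unchanged; when capturing, every state funnels into `full_graph_attention`,
which records an `ExternalEvent` plus weak-referenced kernel arguments per
token count, brackets the `npu_fused_infer_attention_score.out` call with
`graph_task_group_begin`/`graph_task_group_end`, and saves the handle so that
`update_attn_params` (see the acl_graph hunk below) can patch sequence lengths
at replay. A condensed, illustrative sketch of that handshake follows; the
`graph_params` bookkeeping is simplified and this is not the literal
implementation:

```python
# Condensed sketch of the ACL-graph capture/update handshake, assuming the
# torch_npu graph-task APIs that appear in this diff; bookkeeping simplified.
import torch
import torch_npu


def capture(num_tokens, graph_params, launch_kernel):
    stream = torch_npu.npu.current_stream()
    event = torch.npu.ExternalEvent()
    event.wait(stream)    # replay blocks here until the update records the event
    event.reset(stream)
    graph_params.events[num_tokens].append(event)

    torch.npu.graph_task_group_begin(stream)
    launch_kernel()       # e.g. npu_fused_infer_attention_score.out(..., workspace=...)
    handle = torch.npu.graph_task_group_end(stream)
    graph_params.handles[num_tokens].append(handle)


def replay_update(update_stream, num_tokens, graph_params, launch_kernel):
    # Invoked with the runtime seq_lens_list / query_start_loc_list before replay.
    for handle, event in zip(graph_params.handles[num_tokens],
                             graph_params.events[num_tokens]):
        torch.npu.graph_task_update_begin(update_stream, handle)
        launch_kernel()   # same kernel, updated sequence lengths
        torch.npu.graph_task_update_end(update_stream)
        event.record(update_stream)
```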

View File

@@ -1278,8 +1278,7 @@ class AscendMLAImpl(MLAAttentionImpl):
if workspace is None:
workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
q_nope, k_nope, k_nope, **common_kwargs)
update_graph_params_workspaces(num_tokens,
weak_ref_tensors(workspace))
update_graph_params_workspaces(num_tokens, workspace)
attn_output = torch.empty_like(q_nope)
softmax_lse = torch.empty(num_tokens,
@@ -1779,8 +1778,7 @@ class AscendMLAImpl(MLAAttentionImpl):
q_nope, q_pe, k_nope, k_pe, decode_meta.block_table,
seq_len, num_heads, self.scale, self.num_kv_heads,
**common_kwargs)
update_graph_params_workspaces(num_tokens,
weak_ref_tensors(workspace))
update_graph_params_workspaces(num_tokens, workspace)
attn_output = torch.empty_like(q_nope)
softmax_lse = torch.empty((num_tokens, num_heads, 1),
dtype=q_nope.dtype,

View File

@@ -88,6 +88,8 @@ class AscendCommonAttentionMetadata:
attn_mask: torch.Tensor = None
fia_attn_mask: torch.Tensor = None
spec_attn_mask: torch.Tensor = None
attn_state: Any = None

View File

@@ -203,48 +203,31 @@ def update_attn_params(update_stream, forward_context, runtime_shape):
graph_params.handles[runtime_shape],
graph_params.events[runtime_shape],
):
(
query,
key_cache,
value_cache,
num_kv_heads,
num_heads,
scale,
block_table,
seq_lens,
output,
) = param
seq_lens = forward_context.attn_metadata[key].seq_lens
(query, key_cache, value, block_tables, attn_mask, block_size,
seq_lens, query_start_loc, num_kv_heads, num_heads, scale,
attn_output, softmax_lse) = param
# When using FULL_DECODE_ONLY, there are some rare bugs for FULL_DECODE_ONLY
# mode with GQA. This is triggered by getting workspace for _npu_paged_attention
# in torch_npu. On some rare cases, _npu_paged_attention with smaller seq_lens
# might encounter a bigger workspace, while currently we use max_model_len to
# calculate max workspace in capturing. So additional get_workspace is added
# here to avoid such bugs.
# TODO(Angazenn): we will remove this once _npu_paged_attention is fully
# replaced by npu_fused_infer_attention_score which does not contain such bugs.
workspace = torch_npu._npu_paged_attention_get_workspace(
query=query,
key_cache=key_cache,
value_cache=value_cache,
num_kv_heads=num_kv_heads,
num_heads=num_heads,
scale_value=scale,
block_table=block_table,
context_lens=seq_lens,
out=output)
seq_lens = forward_context.attn_metadata[key].seq_lens_list
query_start_loc = forward_context.attn_metadata[
key].query_start_loc_list
torch.npu.graph_task_update_begin(update_stream, handle)
torch_npu._npu_paged_attention(query=query,
key_cache=key_cache,
value_cache=value_cache,
num_kv_heads=num_kv_heads,
num_heads=num_heads,
scale_value=scale,
block_table=block_table,
context_lens=seq_lens,
out=output,
workspace=workspace)
torch_npu.npu_fused_infer_attention_score.out(
query=query,
key=key_cache,
value=value,
block_table=block_tables,
atten_mask=attn_mask,
input_layout="TND",
block_size=block_size,
actual_seq_lengths=query_start_loc,
actual_seq_lengths_kv=seq_lens,
num_key_value_heads=num_kv_heads,
num_heads=num_heads,
scale=scale,
sparse_mode=3,
workspace=graph_params.workspaces.get(runtime_shape),
out=[attn_output, softmax_lse],
)
torch.npu.graph_task_update_end(update_stream)
event.record(update_stream)
@@ -446,10 +429,10 @@ def set_graph_params(aclgraph_capture_sizes: set[int]):
)
def update_graph_params_workspaces(num_tokens: int, workspace: Any):
def update_graph_params_workspaces(num_tokens: int, workspace: int):
global _graph_params
if _graph_params is not None:
_graph_params.workspaces[num_tokens] = workspace
_graph_params.workspaces[num_tokens] = weak_ref_tensors(workspace)
def get_graph_params():

View File

@@ -233,7 +233,8 @@ class NPUPlatform(Platform):
"vllm.mla_forward"
])
update_aclgraph_sizes(vllm_config)
elif compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
elif compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY or\
compilation_config.cudagraph_mode == CUDAGraphMode.FULL:
logger.info(
"FULL_DECODE_ONLY compilation enabled on NPU. use_inductor not supported - "
"using only ACL Graph mode")
@@ -270,7 +271,8 @@ class NPUPlatform(Platform):
compilation_config.use_inductor = False
compilation_config.splitting_ops.extend(["vllm::mla_forward"])
update_aclgraph_sizes(vllm_config)
elif compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
elif compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY or\
compilation_config.cudagraph_mode == CUDAGraphMode.FULL:
logger.info(
"FULL_DECODE_ONLY compilation enabled on NPU. use_inductor not supported - "
"using only ACL Graph mode")

View File

@@ -21,6 +21,7 @@ import math
import types
from typing import Any, Optional
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
@@ -31,7 +32,6 @@ from vllm.distributed.parallel_state import get_dp_group
from vllm.forward_context import get_forward_context
from vllm.logger import logger
import numpy as np
import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.platform import NPUPlatform

View File

@@ -331,6 +331,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.attn_groups: list[list[AttentionGroup]] = []
self.encoder_cache: Dict[str, torch.Tensor] = {}
self.attn_mask = None
self.fia_attn_mask = None
self.attn_state = None
self.requests: Dict[str, CachedRequestState] = {}
self.intermediate_tensors: Optional[IntermediateTensors] = None
@@ -1030,6 +1031,11 @@ class NPUModelRunner(LoRAModelRunnerMixin):
else:
return None
def _make_fia_attention_mask(self) -> torch.Tensor:
if self.attn_mask_builder is None:
raise ValueError("Attn mask builder is None")
return self.attn_mask_builder.get_splitfuse_attn_mask()
def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"):
mrope_pos_ptr = 0
for index, req_id in enumerate(self.input_batch.req_ids):
@@ -1667,6 +1673,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.attn_mask = self._make_attention_mask(seq_lens=seq_lens_cpu,
position=positions_cpu,
attn_state=attn_state)
self.fia_attn_mask = self._make_fia_attention_mask()
self.attn_state = attn_state # type: ignore
self.with_prefill = with_prefill
@@ -1899,6 +1906,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
num_computed_tokens_cpu=num_computed_tokens_cpu,
positions=self.positions,
attn_mask=self.attn_mask,
fia_attn_mask=self.fia_attn_mask,
spec_attn_mask=self.spec_attn_mask,
attn_state=self.attn_state,
is_only_prefill=bool(np.all(num_valid_tokens != 1)),
@@ -2756,13 +2764,18 @@ class NPUModelRunner(LoRAModelRunnerMixin):
cu_num_tokens, arange = self._get_cumsum_and_arange(
num_scheduled_tokens)
query_start_loc_tensor = torch.Tensor(cu_num_tokens).to(
self.device).to(torch.int32)
self.query_start_loc[1:num_reqs + 1] = query_start_loc_tensor
self.query_start_loc[1:num_reqs + 1] = torch.Tensor(cu_num_tokens)
self.query_start_loc_cpu[1:num_reqs +
1] = torch.Tensor(cu_num_tokens)
self.query_lens = torch.from_numpy(num_scheduled_tokens)
assigned_mask_dim = 2048
self.fia_attn_mask = torch.triu(torch.ones(assigned_mask_dim,
assigned_mask_dim),
diagonal=1).to(torch.int8).to(
self.device)
num_computed_tokens_cpu = (
self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs])
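
For reference, the `fia_attn_mask` constructed in the dummy-run path above is
a fixed 2048x2048 upper-triangular int8 causal mask; with `sparse_mode=3`,
`npu_fused_infer_attention_score` treats it as a compressed causal mask
independent of the actual sequence lengths. A standalone sketch (device
placement elided):

```python
# Standalone sketch of the compressed causal mask built for
# npu_fused_infer_attention_score (sparse_mode=3): int8 upper-triangular,
# fixed at 2048x2048 regardless of sequence length.
import torch

assigned_mask_dim = 2048
fia_attn_mask = torch.triu(
    torch.ones(assigned_mask_dim, assigned_mask_dim),
    diagonal=1,
).to(torch.int8)   # .to(device) in the real runner
```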
@@ -2805,6 +2818,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
num_computed_tokens_cpu=num_computed_tokens_cpu,
positions=self.positions,
attn_mask=self.attn_mask,
fia_attn_mask=self.fia_attn_mask,
spec_attn_mask=self.spec_attn_mask,
attn_state=self.attn_state,
max_query_len=max_query_len,
@@ -3978,10 +3992,12 @@ class NPUModelRunner(LoRAModelRunnerMixin):
graph_support = None
if hasattr(builder, 'aclgraph_support'):
graph_support = builder.aclgraph_support.value
builder_aclgraph = builder.aclgraph_support
else:
graph_support = builder.cudagraph_support.value
builder_aclgraph = builder.cudagraph_support
if graph_support < min_ag_support.value:
min_ag_support = builder.aclgraph_support
min_ag_support = builder_aclgraph
min_ag_builder_name = builder.__class__.__name__
# This is an imitation of compilation_config.splitting_ops_contain_attention()