support FULL graph mode for GQA (#3970)

### What this PR does / why we need it?
The current library only supports the FullDecodeOnly graph mode, which
enables full graph execution during the decode. This PR extends support
to allow full graph execution in both the prefill and decode, referred
to as FULL graph mode.

- vLLM version: v0.11.0
- vLLM main:
2918c1b49c

Signed-off-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com>
Co-authored-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com>
This commit is contained in:
XiaoxinWang
2025-11-17 10:50:35 +08:00
committed by GitHub
parent c334114f69
commit e38ef2c434
11 changed files with 328 additions and 296 deletions

View File

@@ -180,6 +180,7 @@ jobs:
if: ${{ inputs.type == 'full' }} if: ${{ inputs.type == 'full' }}
run: | run: |
pytest -sv tests/e2e/multicard/test_torchair_graph_mode.py pytest -sv tests/e2e/multicard/test_torchair_graph_mode.py
pytest -sv tests/e2e/multicard/test_full_graph_mode.py
pytest -sv tests/e2e/multicard/test_data_parallel.py pytest -sv tests/e2e/multicard/test_data_parallel.py
pytest -sv tests/e2e/multicard/test_expert_parallel.py pytest -sv tests/e2e/multicard/test_expert_parallel.py
pytest -sv tests/e2e/multicard/test_external_launcher.py pytest -sv tests/e2e/multicard/test_external_launcher.py

View File

@@ -29,7 +29,7 @@ from tests.e2e.conftest import VllmRunner
from tests.e2e.model_utils import check_outputs_equal from tests.e2e.model_utils import check_outputs_equal
def test_models_distributed_Qwen3_MOE_TP2_WITH_FULLGRAPH(): def test_models_distributed_Qwen3_MOE_TP2_WITH_FULL_DECODE_ONLY():
if 'HCCL_OP_EXPANSION_MODE' in os.environ: if 'HCCL_OP_EXPANSION_MODE' in os.environ:
del os.environ['HCCL_OP_EXPANSION_MODE'] del os.environ['HCCL_OP_EXPANSION_MODE']
prompts = [ prompts = [
@@ -42,15 +42,64 @@ def test_models_distributed_Qwen3_MOE_TP2_WITH_FULLGRAPH():
max_model_len=1024, max_model_len=1024,
tensor_parallel_size=2, tensor_parallel_size=2,
enforce_eager=False, enforce_eager=False,
compilation_config={"cudagraph_mode": compilation_config={
"FULL_DECODE_ONLY"}) as runner: "cudagraph_mode": "FULL_DECODE_ONLY",
"cudagraph_capture_sizes": [4, 8, 24, 48, 60]
}) as runner:
vllm_fullgraph_outputs = runner.model.generate(prompts, vllm_fullgraph_outputs = runner.model.generate(prompts,
sampling_params) sampling_params)
with VllmRunner( with VllmRunner(
model, model,
max_model_len=1024, max_model_len=1024,
enforce_eager=True, tensor_parallel_size=2,
enforce_eager=False,
) as runner:
vllm_eager_outputs = runner.model.generate(prompts, sampling_params)
vllm_fullgraph_outputs_list = []
for output in vllm_fullgraph_outputs:
vllm_fullgraph_outputs_list.append(
(output.outputs[0].index, output.outputs[0].text))
vllm_eager_outputs_list = []
for output in vllm_eager_outputs:
vllm_eager_outputs_list.append(
(output.outputs[0].index, output.outputs[0].text))
check_outputs_equal(
outputs_0_lst=vllm_eager_outputs_list,
outputs_1_lst=vllm_fullgraph_outputs_list,
name_0="vllm_eager_outputs",
name_1="vllm_fullgraph_outputs",
)
def test_models_distributed_Qwen3_MOE_TP2_WITH_FULL():
if 'HCCL_OP_EXPANSION_MODE' in os.environ:
del os.environ['HCCL_OP_EXPANSION_MODE']
prompts = [
"Hello, my name is", "The president of the United States is",
"The capital of France is", "The future of AI is"
]
model = "Qwen/Qwen3-30B-A3B"
sampling_params = SamplingParams(max_tokens=32, temperature=0.0)
with VllmRunner(model,
max_model_len=1024,
tensor_parallel_size=2,
enforce_eager=False,
compilation_config={
"cudagraph_mode": "FULL",
"cudagraph_capture_sizes": [4, 8, 24, 48, 60]
}) as runner:
vllm_fullgraph_outputs = runner.model.generate(prompts,
sampling_params)
with VllmRunner(
model,
max_model_len=1024,
tensor_parallel_size=2,
enforce_eager=False,
) as runner: ) as runner:
vllm_eager_outputs = runner.model.generate(prompts, sampling_params) vllm_eager_outputs = runner.model.generate(prompts, sampling_params)

View File

@@ -46,7 +46,7 @@ def mtp_correctness(sampling_config: SamplingParams,
graph_mode_str = "PIECEWISE" graph_mode_str = "PIECEWISE"
if graph_mode == CUDAGraphMode.FULL: if graph_mode == CUDAGraphMode.FULL:
graph_mode_str = "FULL" graph_mode_str = "FULL_DECODE_ONLY"
with VllmRunner( with VllmRunner(
model_name, model_name,
@@ -63,7 +63,9 @@ def mtp_correctness(sampling_config: SamplingParams,
enforce_eager=enforce_eager, enforce_eager=enforce_eager,
max_model_len=2000, max_model_len=2000,
compilation_config=CompilationConfig( compilation_config=CompilationConfig(
cudagraph_mode=graph_mode_str), cudagraph_mode=graph_mode_str,
cudagraph_capture_sizes=[12],
),
additional_config={"ascend_scheduler_config": { additional_config={"ascend_scheduler_config": {
"enabled": False "enabled": False
}}) as spec_llm: }}) as spec_llm:

View File

@@ -286,10 +286,12 @@ class TestAscendAttentionBackendImpl(TestBase):
assert output.shape == (10, 8 * 64) assert output.shape == (10, 8 * 64)
@patch('vllm_ascend.attention.attention_v1.get_forward_context')
@patch('torch_npu._npu_reshape_and_cache') @patch('torch_npu._npu_reshape_and_cache')
@patch('torch_npu._npu_flash_attention') @patch('torch_npu._npu_flash_attention')
def test_forward_prefill_no_cache(self, mock_flash_attention, def test_forward_prefill_no_cache(self, mock_flash_attention,
mock_reshape_cache): mock_reshape_cache,
mock_get_forward_context):
"""Test forward pass in PrefillNoCache state""" """Test forward pass in PrefillNoCache state"""
query = torch.randn(10, 8 * 64) query = torch.randn(10, 8 * 64)
key = torch.randn(10, 8 * 64) key = torch.randn(10, 8 * 64)
@@ -297,6 +299,8 @@ class TestAscendAttentionBackendImpl(TestBase):
kv_cache = torch.empty(2, 5, 128, 8, 64) kv_cache = torch.empty(2, 5, 128, 8, 64)
output = torch.empty_like(query) output = torch.empty_like(query)
mock_get_forward_context.return_value = MagicMock(capturing=False)
metadata = self.attn_metadata metadata = self.attn_metadata
metadata.attn_state = AscendAttentionState.PrefillNoCache metadata.attn_state = AscendAttentionState.PrefillNoCache
metadata.attn_mask = torch.randn(1, 1, 10, 10) metadata.attn_mask = torch.randn(1, 1, 10, 10)
@@ -316,7 +320,8 @@ class TestAscendAttentionBackendImpl(TestBase):
@patch('torch_npu._npu_reshape_and_cache') @patch('torch_npu._npu_reshape_and_cache')
@patch('torch_npu.npu_fused_infer_attention_score') @patch('torch_npu.npu_fused_infer_attention_score')
def test_forward_prefill_cache_hit(self, @patch('vllm_ascend.attention.attention_v1.get_forward_context')
def test_forward_prefill_cache_hit(self, mock_get_forward_context,
mock_npu_fused_infer_attention_score, mock_npu_fused_infer_attention_score,
mock_npu_reshape_and_cache): mock_npu_reshape_and_cache):
"""Test forward pass in PrefillCacheHit state""" """Test forward pass in PrefillCacheHit state"""
@@ -326,8 +331,6 @@ class TestAscendAttentionBackendImpl(TestBase):
kv_cache = torch.empty(2, 5, 128, 8, 64) kv_cache = torch.empty(2, 5, 128, 8, 64)
output = torch.empty_like(query) output = torch.empty_like(query)
mock_npu_fused_infer_attention_score.return_value = (output, 1)
metadata = self.attn_metadata metadata = self.attn_metadata
metadata.attn_state = AscendAttentionState.PrefillCacheHit metadata.attn_state = AscendAttentionState.PrefillCacheHit
metadata.attn_mask = torch.randn(1, 1, 10, 10) metadata.attn_mask = torch.randn(1, 1, 10, 10)
@@ -340,18 +343,23 @@ class TestAscendAttentionBackendImpl(TestBase):
metadata.num_prefills = 10 metadata.num_prefills = 10
layer = self.layer_no_quant layer = self.layer_no_quant
mock_get_forward_context.return_value = MagicMock(capturing=False)
mock_npu_fused_infer_attention_score.return_value = (output,
torch.ones(
10, 8, 64))
output = self.impl.forward(layer, query, key, value, kv_cache, output = self.impl.forward(layer, query, key, value, kv_cache,
metadata, output) metadata, output)
mock_npu_fused_infer_attention_score.assert_called_once() mock_npu_fused_infer_attention_score.assert_called_once()
assert output.shape == (10, 8 * 64) assert output.shape == (10, 8 * 64)
@patch('vllm_ascend.attention.attention_v1.get_forward_context')
@patch('torch_npu._npu_reshape_and_cache')
@patch('torch_npu._npu_paged_attention') @patch('torch_npu._npu_paged_attention')
def test_forward_decode_only(self, mock_paged_attention, @patch('torch_npu._npu_reshape_and_cache')
@patch('vllm_ascend.attention.attention_v1.get_forward_context')
def test_forward_decode_only(self, mock_get_forward_context,
mock_npu_reshape_and_cache, mock_npu_reshape_and_cache,
mock_get_forward_context): mock_paged_attention):
"""Test forward pass in DecodeOnly state""" """Test forward pass in DecodeOnly state"""
query = torch.randn(10, 8 * 64) query = torch.randn(10, 8 * 64)
key = torch.randn(10, 8 * 64) key = torch.randn(10, 8 * 64)
@@ -378,115 +386,11 @@ class TestAscendAttentionBackendImpl(TestBase):
assert output.shape == (10, 8 * 64) assert output.shape == (10, 8 * 64)
@patch('vllm_ascend.attention.attention_v1.get_forward_context') @patch('vllm_ascend.attention.attention_v1.get_forward_context')
@patch('vllm_ascend.attention.attention_v1.get_graph_params')
@patch('torch_npu._npu_reshape_and_cache')
@patch('torch_npu._npu_paged_attention')
@patch('torch.npu.graph_task_group_end')
@patch('torch.npu.graph_task_group_begin')
@patch('torch.npu.ExternalEvent')
@patch('torch_npu.npu.current_stream')
@patch('vllm_ascend.attention.attention_v1.weak_ref_tensors')
def test_paged_attention_with_existing_workspace(
self,
mock_get_forward_context,
mock_get_graph_params,
mock_npu_reshape_and_cache,
mock_paged_attention,
mock_graph_begin,
mock_graph_end,
mock_external_event_class,
mock_current_stream,
mock_weak_ref_tensors,
):
graph_params = MagicMock()
attn_metadata = MagicMock()
num_tokens = 10
graph_params.workspaces = {num_tokens: 10}
graph_params.events = {num_tokens: []}
graph_params.attn_params = {num_tokens: []}
graph_params.handles = {num_tokens: []}
query = torch.randn(2, 5, 8) # [batch_size, seq_len, hidden_size]
key_cache = MagicMock()
value_cache = MagicMock()
num_kv_heads = 4
num_heads = 8
scale = 0.1
output = torch.randn(2, 5, 8)
self_obj = MagicMock()
self_obj.key_cache = key_cache
self_obj.value_cache = value_cache
self_obj.num_kv_heads = num_kv_heads
self_obj.num_heads = num_heads
self_obj.scale = scale
mock_stream = MagicMock()
mock_current_stream.return_value = mock_stream
mock_event_instance = MagicMock()
mock_external_event_class.return_value = mock_event_instance
mock_handle = MagicMock()
mock_graph_end.return_value = mock_handle
workspace = graph_params.workspaces.get(num_tokens)
self.assertEqual(workspace, 10)
weak_ref_tensors = MagicMock(side_effect=lambda x: x)
# 2. Handle graph capturing mode
stream = mock_current_stream()
event = mock_external_event_class()
event.wait(stream)
event.reset(stream)
graph_params.events[num_tokens].append(event)
graph_params.attn_params[num_tokens].append((
weak_ref_tensors(query),
weak_ref_tensors(self_obj.key_cache),
weak_ref_tensors(self_obj.value_cache),
self_obj.num_kv_heads,
self_obj.num_heads,
self_obj.scale,
weak_ref_tensors(attn_metadata.block_tables),
attn_metadata.seq_lens,
output,
))
mock_event_instance.wait.assert_called_once_with(mock_stream)
mock_event_instance.reset.assert_called_once_with(mock_stream)
self.assertEqual(len(graph_params.events[num_tokens]), 1)
self.assertEqual(len(graph_params.attn_params[num_tokens]), 1)
query = torch.randn(10, 8 * 64)
key = torch.randn(10, 8 * 64)
value = torch.randn(10, 8 * 64)
kv_cache = torch.empty(2, 5, 128, 8, 64)
output = torch.empty_like(query)
metadata = self.attn_metadata
metadata.attn_state = AscendAttentionState.DecodeOnly
metadata.seq_lens = torch.tensor([10])
metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
metadata.num_actual_tokens = 10
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
metadata.num_decodes = 0
metadata.num_prefills = 10
layer = self.layer_no_quant
mock_get_forward_context.return_value = MagicMock(capturing=True)
mock_get_graph_params.return_value = graph_params
output = self.impl.forward(layer, query, key, value, kv_cache,
metadata, output)
mock_paged_attention.assert_called_once()
self.assertEqual(len(graph_params.handles[num_tokens]), 0)
@patch('torch_npu._npu_reshape_and_cache')
@patch('torch_npu.npu_fused_infer_attention_score') @patch('torch_npu.npu_fused_infer_attention_score')
def test_forward_decode_only_swa(self, mock_fused_infer_attention_score, @patch('torch_npu._npu_reshape_and_cache')
mock_npu_reshape_and_cache): def test_forward_decode_only_swa(self, mock_npu_reshape_and_cache,
mock_fused_infer_attention_score,
mock_get_forward_context):
"""Test forward pass in DecodeOnly state""" """Test forward pass in DecodeOnly state"""
query = torch.randn(10, 8 * 64) query = torch.randn(10, 8 * 64)
key = torch.randn(10, 8 * 64) key = torch.randn(10, 8 * 64)
@@ -494,6 +398,8 @@ class TestAscendAttentionBackendImpl(TestBase):
kv_cache = torch.empty(2, 5, 128, 8, 64) kv_cache = torch.empty(2, 5, 128, 8, 64)
output = torch.empty(10, 8, 64) output = torch.empty(10, 8, 64)
mock_get_forward_context.return_value = MagicMock(capturing=False)
metadata = self.attn_metadata metadata = self.attn_metadata
metadata.attn_state = AscendAttentionState.DecodeOnly metadata.attn_state = AscendAttentionState.DecodeOnly
metadata.seq_lens = torch.tensor([10] * 10) metadata.seq_lens = torch.tensor([10] * 10)
@@ -512,12 +418,12 @@ class TestAscendAttentionBackendImpl(TestBase):
assert output.shape == (10, 8, 64) assert output.shape == (10, 8, 64)
@patch('vllm_ascend.attention.attention_v1.get_forward_context') @patch('vllm_ascend.attention.attention_v1.get_forward_context')
@patch('torch_npu._npu_reshape_and_cache')
@patch('torch_npu._npu_paged_attention') @patch('torch_npu._npu_paged_attention')
@patch('torch_npu.npu_fused_infer_attention_score') @patch('torch_npu.npu_fused_infer_attention_score')
@patch('torch_npu._npu_reshape_and_cache')
def test_forward_decode_only_swa_seq_len_mismatch( def test_forward_decode_only_swa_seq_len_mismatch(
self, mock_fused_infer_attention_score, mock_paged_attention, self, mock_npu_reshape_and_cache, mock_fused_infer_attention_score,
mock_npu_reshape_and_cache, mock_get_forward_context): mock_paged_attention, mock_get_forward_context):
"""Test forward pass in DecodeOnly state when seq)len_mismatch""" """Test forward pass in DecodeOnly state when seq)len_mismatch"""
query = torch.randn(10, 8 * 64) query = torch.randn(10, 8 * 64)
key = torch.randn(10, 8 * 64) key = torch.randn(10, 8 * 64)
@@ -535,11 +441,11 @@ class TestAscendAttentionBackendImpl(TestBase):
metadata.num_decodes = 10 metadata.num_decodes = 10
metadata.num_prefills = 0 metadata.num_prefills = 0
mock_fused_infer_attention_score.return_value = (torch.ones(10, 8,
64), 1)
mock_get_forward_context.return_value = MagicMock(capturing=False) mock_get_forward_context.return_value = MagicMock(capturing=False)
mock_fused_infer_attention_score.return_value = (torch.ones(10, 8, 64),
torch.ones(10, 8, 64))
output = self.impl_swa.forward(layer, query, key, value, kv_cache, output = self.impl_swa.forward(layer, query, key, value, kv_cache,
metadata, output) metadata, output)
@@ -548,11 +454,13 @@ class TestAscendAttentionBackendImpl(TestBase):
assert output.shape == (10, 8 * 64) assert output.shape == (10, 8 * 64)
@patch('vllm_ascend.attention.attention_v1.get_forward_context')
@patch('vllm_ascend.attention.attention_v1.is_310p', return_value=False) @patch('vllm_ascend.attention.attention_v1.is_310p', return_value=False)
@patch('torch_npu._npu_reshape_and_cache') @patch('torch_npu._npu_reshape_and_cache')
@patch('vllm_ascend.attention.attention_v1.vanilla_chunked_prefill') @patch('vllm_ascend.attention.attention_v1.vanilla_chunked_prefill')
def test_forward_head_size_192(self, mock_vanilla_prefill, def test_forward_head_size_192(self, mock_vanilla_prefill,
mock_npu_reshape_and_cache, mock_is_310p): mock_npu_reshape_and_cache, mock_is_310p,
mock_get_forward_context):
"""Test forward pass when head_size is 192""" """Test forward pass when head_size is 192"""
self.impl.head_size = 192 self.impl.head_size = 192
@@ -562,6 +470,8 @@ class TestAscendAttentionBackendImpl(TestBase):
kv_cache = torch.empty(2, 5, 128, 8, 192) kv_cache = torch.empty(2, 5, 128, 8, 192)
output = torch.empty_like(query) output = torch.empty_like(query)
mock_get_forward_context.return_value = MagicMock(capturing=False)
metadata = self.attn_metadata metadata = self.attn_metadata
metadata.attn_mask = torch.randn(1, 1, 10, 10) metadata.attn_mask = torch.randn(1, 1, 10, 10)
metadata.query_lens = torch.tensor([10]) metadata.query_lens = torch.tensor([10])
@@ -580,11 +490,12 @@ class TestAscendAttentionBackendImpl(TestBase):
mock_vanilla_prefill.assert_called_once() mock_vanilla_prefill.assert_called_once()
assert output.shape == (10, 8 * 192) assert output.shape == (10, 8 * 192)
@patch('torch_npu._npu_reshape_and_cache') @patch('vllm_ascend.attention.attention_v1.get_forward_context')
@patch('torch_npu.npu_fused_infer_attention_score') @patch('torch_npu.npu_fused_infer_attention_score')
def test_forward_normal_v1_situation(self, @patch('torch_npu._npu_reshape_and_cache')
def test_forward_normal_v1_situation(self, mock_npu_reshape_and_cache,
mock_npu_fused_infer_attention_score, mock_npu_fused_infer_attention_score,
mock_npu_reshape_and_cache): mock_get_forward_context):
"""Test forward pass in normal V1 situation""" """Test forward pass in normal V1 situation"""
query = torch.randn(10, 8 * 64) query = torch.randn(10, 8 * 64)
key = torch.randn(10, 8 * 64) key = torch.randn(10, 8 * 64)
@@ -592,8 +503,6 @@ class TestAscendAttentionBackendImpl(TestBase):
kv_cache = torch.empty(2, 5, 128, 8, 64) kv_cache = torch.empty(2, 5, 128, 8, 64)
output = torch.empty_like(query) output = torch.empty_like(query)
mock_npu_fused_infer_attention_score.return_value = (output, 1)
metadata = self.attn_metadata metadata = self.attn_metadata
metadata.attn_mask = torch.randn(1, 1, 10, 10) metadata.attn_mask = torch.randn(1, 1, 10, 10)
metadata.query_lens = torch.tensor([10]) metadata.query_lens = torch.tensor([10])
@@ -604,6 +513,10 @@ class TestAscendAttentionBackendImpl(TestBase):
metadata.num_decodes = 0 metadata.num_decodes = 0
metadata.num_prefills = 10 metadata.num_prefills = 10
layer = self.layer_no_quant layer = self.layer_no_quant
mock_get_forward_context.return_value = MagicMock(capturing=False)
mock_npu_fused_infer_attention_score.return_value = (output,
torch.ones(
10, 8, 64))
output = self.impl.forward(layer, query, key, value, kv_cache, output = self.impl.forward(layer, query, key, value, kv_cache,
metadata, output) metadata, output)
@@ -615,7 +528,8 @@ class TestAscendAttentionBackendImpl(TestBase):
@patch('torch_npu._npu_reshape_and_cache') @patch('torch_npu._npu_reshape_and_cache')
@patch('torch_npu.npu_fused_infer_attention_score') @patch('torch_npu.npu_fused_infer_attention_score')
@patch('vllm_ascend.attention.attention_v1.is_310p', return_value=True) @patch('vllm_ascend.attention.attention_v1.is_310p', return_value=True)
def test_forward_310p_device(self, mock_is_310p, @patch('vllm_ascend.attention.attention_v1.get_forward_context')
def test_forward_310p_device(self, mock_get_forward_context, mock_is_310p,
mock_npu_fused_infer_attention_score, mock_npu_fused_infer_attention_score,
mock_npu_reshape_and_cache, mock_npu_reshape_and_cache,
mock_npu_format_cast): mock_npu_format_cast):
@@ -626,8 +540,6 @@ class TestAscendAttentionBackendImpl(TestBase):
kv_cache = torch.empty(2, 5, 128, 8, 64) kv_cache = torch.empty(2, 5, 128, 8, 64)
output = torch.empty_like(query) output = torch.empty_like(query)
mock_npu_fused_infer_attention_score.return_value = (output, 1)
metadata = self.attn_metadata metadata = self.attn_metadata
metadata.attn_mask = torch.randn(1, 1, 10, 10) metadata.attn_mask = torch.randn(1, 1, 10, 10)
metadata.query_lens = torch.tensor([10]) metadata.query_lens = torch.tensor([10])
@@ -641,6 +553,11 @@ class TestAscendAttentionBackendImpl(TestBase):
mock_npu_format_cast.return_value = metadata.attn_mask mock_npu_format_cast.return_value = metadata.attn_mask
mock_get_forward_context.return_value = MagicMock(capturing=False)
mock_npu_fused_infer_attention_score.return_value = (output,
torch.ones(
10, 8, 64))
output = self.impl.forward(layer, query, key, value, kv_cache, output = self.impl.forward(layer, query, key, value, kv_cache,
metadata, output) metadata, output)

View File

@@ -195,6 +195,7 @@ class AscendMetadataForDecode:
class AscendMetadata: class AscendMetadata:
# **************************** Basic Properties ************************** # # **************************** Basic Properties ************************** #
attn_mask: Optional[torch.Tensor] = None attn_mask: Optional[torch.Tensor] = None
fia_attn_mask: Optional[torch.Tensor] = None
# Current state of this attention run. # Current state of this attention run.
attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
@@ -215,6 +216,7 @@ class AscendMetadata:
seq_lens: torch.Tensor = None seq_lens: torch.Tensor = None
seq_lens_list: List[int] = None # type: ignore seq_lens_list: List[int] = None # type: ignore
actual_seq_lengths_q: List[int] = None # type: ignore actual_seq_lengths_q: List[int] = None # type: ignore
query_start_loc_list: List[int] = None # type: ignore
query_start_loc: torch.Tensor = None query_start_loc: torch.Tensor = None
query_lens: torch.Tensor = None query_lens: torch.Tensor = None
@@ -241,7 +243,8 @@ class AscendMetadata:
class AscendAttentionMetadataBuilder: class AscendAttentionMetadataBuilder:
# Does this backend/builder support ACL Graphs for attention (default: no). # Does this backend/builder support ACL Graphs for attention (default: no).
aclgraph_support: ClassVar[AttentionCGSupport] = \ aclgraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE AttentionCGSupport.ALWAYS
# AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
# Does this backend/builder reorder the batch? # Does this backend/builder reorder the batch?
# If not, set this to None. Otherwise set it to the query # If not, set this to None. Otherwise set it to the query
# length that will be pulled into the front of the batch. # length that will be pulled into the front of the batch.
@@ -321,6 +324,7 @@ class AscendAttentionMetadataBuilder:
num_actual_tokens_pcp_padded] num_actual_tokens_pcp_padded]
# slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens] # slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens]
attn_mask = common_attn_metadata.attn_mask attn_mask = common_attn_metadata.attn_mask
fia_attn_mask = common_attn_metadata.fia_attn_mask
attn_state = common_attn_metadata.attn_state attn_state = common_attn_metadata.attn_state
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[: query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:
num_reqs num_reqs
@@ -471,6 +475,7 @@ class AscendAttentionMetadataBuilder:
num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded, num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded,
block_tables=block_table, block_tables=block_table,
query_start_loc=query_start_loc, query_start_loc=query_start_loc,
query_start_loc_list=query_start_loc_cpu[1:].tolist(),
query_lens=query_lens, query_lens=query_lens,
seq_lens=seq_lens, seq_lens=seq_lens,
seq_lens_list=seq_lens.tolist(), seq_lens_list=seq_lens.tolist(),
@@ -478,6 +483,7 @@ class AscendAttentionMetadataBuilder:
actual_seq_lengths_q=query_start_loc_cpu[1:].tolist(), actual_seq_lengths_q=query_start_loc_cpu[1:].tolist(),
slot_mapping=slot_mapping, slot_mapping=slot_mapping,
attn_mask=attn_mask, attn_mask=attn_mask,
fia_attn_mask=fia_attn_mask,
attn_state=attn_state, attn_state=attn_state,
num_prefills=num_prefills, num_prefills=num_prefills,
num_decodes=num_decodes, num_decodes=num_decodes,
@@ -565,6 +571,113 @@ class AscendAttentionBackendImpl(AttentionImpl):
self.dcp_group = get_dcp_group( self.dcp_group = get_dcp_group(
).device_group if self.dcp_size > 1 else None ).device_group if self.dcp_size > 1 else None
def full_graph_attention(self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_metadata: AscendMetadata,
output: torch.Tensor,
num_tokens=0):
if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
block_size = 128
block_table = None
actual_seq_lengths_kv = attn_metadata.query_start_loc_list
elif attn_metadata.attn_state == \
AscendAttentionState.PrefillCacheHit:
batch_size = attn_metadata.query_lens.shape[0]
block_table = attn_metadata.block_tables[:batch_size, :]
num_block, block_size, _, _ = self.key_cache.shape # type: ignore
key = self.key_cache.view( # type: ignore
num_block, block_size, -1)
value = self.value_cache.view( # type: ignore
num_block, block_size, -1)
actual_seq_lengths_kv = attn_metadata.seq_lens_list
elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
num_block, block_size, _, _ = self.key_cache.shape # type: ignore
key = self.key_cache.view( # type: ignore
num_block, block_size, -1)
value = self.value_cache.view( # type: ignore
num_block, block_size, -1)
block_table = attn_metadata.block_tables
actual_seq_lengths_kv = attn_metadata.seq_lens_list
# Normal V1 situation.
else:
num_block, block_size, _, _ = self.key_cache.shape # type: ignore
key = self.key_cache.view( # type: ignore
num_block, block_size, -1)
value = self.value_cache.view( # type: ignore
num_block, block_size, -1)
block_table = attn_metadata.block_tables
actual_seq_lengths_kv = attn_metadata.seq_lens_list
num_tokens = attn_metadata.query_start_loc_list[-1]
query = query[:num_tokens]
graph_params = get_graph_params()
query_start_loc = attn_metadata.query_start_loc_list
# Prepare tensors for attention output
# TODO: Refactor this to step-level instead of layer-level
# Get workspace from cache or calculate it if not present.
workspace = graph_params.workspaces.get(num_tokens)
softmax_lse = torch.empty(1, dtype=query.dtype, device=query.device)
if workspace is None:
workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
query=query,
key=key,
value=value,
atten_mask=attn_metadata.fia_attn_mask,
block_table=block_table,
input_layout="TND",
block_size=block_size,
actual_seq_lengths=query_start_loc,
actual_seq_lengths_kv=actual_seq_lengths_kv,
num_key_value_heads=self.num_kv_heads,
num_heads=self.num_heads,
sparse_mode=3,
scale=self.scale,
)
update_graph_params_workspaces(num_tokens, workspace)
# Handle graph capturing mode
stream = torch_npu.npu.current_stream()
event = torch.npu.ExternalEvent()
event.wait(stream)
event.reset(stream)
graph_params.events[num_tokens].append(event)
graph_params.attn_params[num_tokens].append(
(weak_ref_tensors(query), weak_ref_tensors(key),
weak_ref_tensors(value), weak_ref_tensors(block_table),
weak_ref_tensors(attn_metadata.fia_attn_mask), block_size,
actual_seq_lengths_kv, query_start_loc, self.num_kv_heads,
self.num_heads, self.scale, weak_ref_tensors(output),
weak_ref_tensors(softmax_lse)))
torch.npu.graph_task_group_begin(stream)
torch_npu.npu_fused_infer_attention_score.out(
query=query,
key=key,
value=value,
atten_mask=attn_metadata.fia_attn_mask,
block_table=block_table,
input_layout="TND",
block_size=block_size,
actual_seq_lengths=query_start_loc,
actual_seq_lengths_kv=actual_seq_lengths_kv,
num_key_value_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale=self.scale,
sparse_mode=3,
workspace=workspace,
out=[output, softmax_lse],
)
output = output.view(num_tokens, self.num_heads, self.head_size)
handle = torch.npu.graph_task_group_end(stream)
graph_params.handles[num_tokens].append(handle)
return output, num_tokens
def _forward_prefill_no_cache( def _forward_prefill_no_cache(
self, self,
query: torch.Tensor, query: torch.Tensor,
@@ -692,70 +805,16 @@ class AscendAttentionBackendImpl(AttentionImpl):
output = output.view(batch_size, self.num_heads, self.head_size) output = output.view(batch_size, self.num_heads, self.head_size)
else: else:
graph_params = get_graph_params() torch_npu._npu_paged_attention(
forward_context: ForwardContext = get_forward_context() query=query,
num_tokens = query.shape[0] key_cache=self.key_cache,
if forward_context.capturing: value_cache=self.value_cache,
# Get workspace from cache or calculate it if not present. num_kv_heads=self.num_kv_heads,
workspace = graph_params.workspaces.get(num_tokens) num_heads=self.num_heads,
if workspace is None: scale_value=self.scale,
workspace = torch_npu._npu_paged_attention_get_workspace( block_table=attn_metadata.block_tables,
query=query, context_lens=attn_metadata.seq_lens,
key_cache=self.key_cache, out=output)
value_cache=self.value_cache,
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale_value=self.scale,
block_table=attn_metadata.block_tables,
context_lens=attn_metadata.seq_lens,
out=output)
update_graph_params_workspaces(num_tokens,
weak_ref_tensors(workspace))
# Handle graph capturing mode
stream = torch_npu.npu.current_stream()
event = torch.npu.ExternalEvent()
event.wait(stream)
event.reset(stream)
graph_params.events[num_tokens].append(event)
graph_params.attn_params[num_tokens].append((
weak_ref_tensors(query),
weak_ref_tensors(self.key_cache),
weak_ref_tensors(self.value_cache),
self.num_kv_heads,
self.num_heads,
self.scale,
attn_metadata.block_tables,
attn_metadata.seq_lens,
weak_ref_tensors(output),
))
torch.npu.graph_task_group_begin(stream)
torch_npu._npu_paged_attention(
query=query,
key_cache=self.key_cache,
value_cache=self.value_cache,
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale_value=self.scale,
block_table=attn_metadata.block_tables,
context_lens=attn_metadata.seq_lens,
out=output,
workspace=workspace)
handle = torch.npu.graph_task_group_end(stream)
graph_params.handles[num_tokens].append(handle)
else:
torch_npu._npu_paged_attention(
query=query,
key_cache=self.key_cache,
value_cache=self.value_cache,
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale_value=self.scale,
block_table=attn_metadata.block_tables,
context_lens=attn_metadata.seq_lens,
out=output)
return output return output
def _forward_v1_style( def _forward_v1_style(
@@ -819,7 +878,6 @@ class AscendAttentionBackendImpl(AttentionImpl):
scale=self.scale, scale=self.scale,
sparse_mode=3, sparse_mode=3,
) )
return output return output
def _attention_with_nomask_and_mask(self, q: torch.Tensor, def _attention_with_nomask_and_mask(self, q: torch.Tensor,
@@ -1481,47 +1539,51 @@ class AscendAttentionBackendImpl(AttentionImpl):
num_decode_tokens:attn_metadata. num_decode_tokens:attn_metadata.
num_actual_tokens_pcp_padded]) num_actual_tokens_pcp_padded])
if self.pcp_size * self.dcp_size > 1: forward_context: ForwardContext = get_forward_context()
intermediate_output = self._forward_pcp_dcp( if not forward_context.capturing:
query, key, value, kv_cache, attn_metadata, output) if self.pcp_size * self.dcp_size > 1:
elif attn_type == AttentionType.ENCODER_ONLY: intermediate_output = self._forward_pcp_dcp(
# TODO(zzzwwjj): Deal with this `cum_seq_len` more elegantly. query, key, value, kv_cache, attn_metadata, output)
cum_seq_len = attn_metadata.query_start_loc[1:].tolist() elif attn_type == AttentionType.ENCODER_ONLY:
intermediate_output = torch_npu.npu_fusion_attention( # TODO(zzzwwjj): Deal with this `cum_seq_len` more elegantly.
query, cum_seq_len = attn_metadata.query_start_loc[1:].tolist()
key, intermediate_output = torch_npu.npu_fusion_attention(
value, query,
head_num=self.num_heads, key,
input_layout="TND", value,
scale=self.scale, head_num=self.num_heads,
sparse_mode=4, input_layout="TND",
atten_mask=attn_metadata.attn_mask, scale=self.scale,
pre_tockens=attn_metadata.max_query_len, sparse_mode=4,
next_tockens=attn_metadata.max_query_len, atten_mask=attn_metadata.attn_mask,
actual_seq_qlen=cum_seq_len, pre_tockens=attn_metadata.max_query_len,
actual_seq_kvlen=cum_seq_len, next_tockens=attn_metadata.max_query_len,
)[0] actual_seq_qlen=cum_seq_len,
# V0-Style scheduler situation. actual_seq_kvlen=cum_seq_len,
elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache: )[0]
intermediate_output = self._forward_prefill_no_cache( # V0-Style scheduler situation.
query, key, value, attn_metadata, output, num_tokens) elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
elif attn_metadata.attn_state == \ intermediate_output = self._forward_prefill_no_cache(
AscendAttentionState.PrefillCacheHit: query, key, value, attn_metadata, output, num_tokens)
intermediate_output = self._forward_prefill_cache_hit( elif attn_metadata.attn_state == \
query, attn_metadata, output) AscendAttentionState.PrefillCacheHit:
elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly: intermediate_output = self._forward_prefill_cache_hit(
intermediate_output = self._forward_decode_only( query, attn_metadata, output)
query, attn_metadata, output) elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
# Normal V1 situation. intermediate_output = self._forward_decode_only(
query, attn_metadata, output)
# Normal V1 situation.
else:
# npu_fused_infer_attention_score does not support cases
# where query.shape[0] != attn_metadata.query_start_loc[-1].
# Thus we need unpad it here.
num_tokens = attn_metadata.query_start_loc[-1]
query = query[:num_tokens]
intermediate_output = self._forward_v1_style(
query, attn_metadata, output)
else: else:
# npu_fused_infer_attention_score does not support cases intermediate_output, num_tokens = self.full_graph_attention(
# where query.shape[0] != attn_metadata.query_start_loc[-1]. query, key, value, attn_metadata, output)
# Thus we need unpad it here.
num_tokens = attn_metadata.query_start_loc[-1]
query = query[:num_tokens]
intermediate_output = self._forward_v1_style(
query, attn_metadata, output)
output[:num_tokens] = intermediate_output[:num_tokens] output[:num_tokens] = intermediate_output[:num_tokens]
return output return output

View File

@@ -1278,8 +1278,7 @@ class AscendMLAImpl(MLAAttentionImpl):
if workspace is None: if workspace is None:
workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace( workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
q_nope, k_nope, k_nope, **common_kwargs) q_nope, k_nope, k_nope, **common_kwargs)
update_graph_params_workspaces(num_tokens, update_graph_params_workspaces(num_tokens, workspace)
weak_ref_tensors(workspace))
attn_output = torch.empty_like(q_nope) attn_output = torch.empty_like(q_nope)
softmax_lse = torch.empty(num_tokens, softmax_lse = torch.empty(num_tokens,
@@ -1779,8 +1778,7 @@ class AscendMLAImpl(MLAAttentionImpl):
q_nope, q_pe, k_nope, k_pe, decode_meta.block_table, q_nope, q_pe, k_nope, k_pe, decode_meta.block_table,
seq_len, num_heads, self.scale, self.num_kv_heads, seq_len, num_heads, self.scale, self.num_kv_heads,
**common_kwargs) **common_kwargs)
update_graph_params_workspaces(num_tokens, update_graph_params_workspaces(num_tokens, workspace)
weak_ref_tensors(workspace))
attn_output = torch.empty_like(q_nope) attn_output = torch.empty_like(q_nope)
softmax_lse = torch.empty((num_tokens, num_heads, 1), softmax_lse = torch.empty((num_tokens, num_heads, 1),
dtype=q_nope.dtype, dtype=q_nope.dtype,

View File

@@ -88,6 +88,8 @@ class AscendCommonAttentionMetadata:
attn_mask: torch.Tensor = None attn_mask: torch.Tensor = None
fia_attn_mask: torch.Tensor = None
spec_attn_mask: torch.Tensor = None spec_attn_mask: torch.Tensor = None
attn_state: Any = None attn_state: Any = None

View File

@@ -203,48 +203,31 @@ def update_attn_params(update_stream, forward_context, runtime_shape):
graph_params.handles[runtime_shape], graph_params.handles[runtime_shape],
graph_params.events[runtime_shape], graph_params.events[runtime_shape],
): ):
( (query, key_cache, value, block_tables, attn_mask, block_size,
query, seq_lens, query_start_loc, num_kv_heads, num_heads, scale,
key_cache, attn_output, softmax_lse) = param
value_cache,
num_kv_heads,
num_heads,
scale,
block_table,
seq_lens,
output,
) = param
seq_lens = forward_context.attn_metadata[key].seq_lens
# When using FULL_DECODE_ONLY, there are some rare bugs for FULL_DECODE_ONLY seq_lens = forward_context.attn_metadata[key].seq_lens_list
# mode with GQA. This is triggered by getting workspace for _npu_paged_attention query_start_loc = forward_context.attn_metadata[
# in torch_npu. On some rare cases, _npu_paged_attention with smaller seq_lens key].query_start_loc_list
# might encounter a bigger workspace, while currently we use max_model_len to
# calculate max workspace in capturing. So additional get_workspace is added
# here to avoid such bugs.
# TODO(Angazenn): we will remove this once _npu_paged_attention is fully
# replaced by npu_fused_infer_attention_score which does not contain such bugs.
workspace = torch_npu._npu_paged_attention_get_workspace(
query=query,
key_cache=key_cache,
value_cache=value_cache,
num_kv_heads=num_kv_heads,
num_heads=num_heads,
scale_value=scale,
block_table=block_table,
context_lens=seq_lens,
out=output)
torch.npu.graph_task_update_begin(update_stream, handle) torch.npu.graph_task_update_begin(update_stream, handle)
torch_npu._npu_paged_attention(query=query, torch_npu.npu_fused_infer_attention_score.out(
key_cache=key_cache, query=query,
value_cache=value_cache, key=key_cache,
num_kv_heads=num_kv_heads, value=value,
num_heads=num_heads, block_table=block_tables,
scale_value=scale, atten_mask=attn_mask,
block_table=block_table, input_layout="TND",
context_lens=seq_lens, block_size=block_size,
out=output, actual_seq_lengths=query_start_loc,
workspace=workspace) actual_seq_lengths_kv=seq_lens,
num_key_value_heads=num_kv_heads,
num_heads=num_heads,
scale=scale,
sparse_mode=3,
workspace=graph_params.workspaces.get(runtime_shape),
out=[attn_output, softmax_lse],
)
torch.npu.graph_task_update_end(update_stream) torch.npu.graph_task_update_end(update_stream)
event.record(update_stream) event.record(update_stream)
@@ -446,10 +429,10 @@ def set_graph_params(aclgraph_capture_sizes: set[int]):
) )
def update_graph_params_workspaces(num_tokens: int, workspace: Any): def update_graph_params_workspaces(num_tokens: int, workspace: int):
global _graph_params global _graph_params
if _graph_params is not None: if _graph_params is not None:
_graph_params.workspaces[num_tokens] = workspace _graph_params.workspaces[num_tokens] = weak_ref_tensors(workspace)
def get_graph_params(): def get_graph_params():

View File

@@ -233,7 +233,8 @@ class NPUPlatform(Platform):
"vllm.mla_forward" "vllm.mla_forward"
]) ])
update_aclgraph_sizes(vllm_config) update_aclgraph_sizes(vllm_config)
elif compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY: elif compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY or\
compilation_config.cudagraph_mode == CUDAGraphMode.FULL:
logger.info( logger.info(
"FULL_DECODE_ONLY compilation enabled on NPU. use_inductor not supported - " "FULL_DECODE_ONLY compilation enabled on NPU. use_inductor not supported - "
"using only ACL Graph mode") "using only ACL Graph mode")
@@ -270,7 +271,8 @@ class NPUPlatform(Platform):
compilation_config.use_inductor = False compilation_config.use_inductor = False
compilation_config.splitting_ops.extend(["vllm::mla_forward"]) compilation_config.splitting_ops.extend(["vllm::mla_forward"])
update_aclgraph_sizes(vllm_config) update_aclgraph_sizes(vllm_config)
elif compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY: elif compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY or\
compilation_config.cudagraph_mode == CUDAGraphMode.FULL:
logger.info( logger.info(
"FULL_DECODE_ONLY compilation enabled on NPU. use_inductor not supported - " "FULL_DECODE_ONLY compilation enabled on NPU. use_inductor not supported - "
"using only ACL Graph mode") "using only ACL Graph mode")

View File

@@ -21,6 +21,7 @@ import math
import types import types
from typing import Any, Optional from typing import Any, Optional
import numpy as np
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import torch.nn as nn import torch.nn as nn
@@ -31,7 +32,6 @@ from vllm.distributed.parallel_state import get_dp_group
from vllm.forward_context import get_forward_context from vllm.forward_context import get_forward_context
from vllm.logger import logger from vllm.logger import logger
import numpy as np
import vllm_ascend.envs as envs_ascend import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.platform import NPUPlatform from vllm_ascend.platform import NPUPlatform

View File

@@ -331,6 +331,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.attn_groups: list[list[AttentionGroup]] = [] self.attn_groups: list[list[AttentionGroup]] = []
self.encoder_cache: Dict[str, torch.Tensor] = {} self.encoder_cache: Dict[str, torch.Tensor] = {}
self.attn_mask = None self.attn_mask = None
self.fia_attn_mask = None
self.attn_state = None self.attn_state = None
self.requests: Dict[str, CachedRequestState] = {} self.requests: Dict[str, CachedRequestState] = {}
self.intermediate_tensors: Optional[IntermediateTensors] = None self.intermediate_tensors: Optional[IntermediateTensors] = None
@@ -1030,6 +1031,11 @@ class NPUModelRunner(LoRAModelRunnerMixin):
else: else:
return None return None
def _make_fia_attention_mask(self) -> torch.Tensor:
if self.attn_mask_builder is None:
raise ValueError("Attn mask builder is None")
return self.attn_mask_builder.get_splitfuse_attn_mask()
def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"): def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"):
mrope_pos_ptr = 0 mrope_pos_ptr = 0
for index, req_id in enumerate(self.input_batch.req_ids): for index, req_id in enumerate(self.input_batch.req_ids):
@@ -1667,6 +1673,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.attn_mask = self._make_attention_mask(seq_lens=seq_lens_cpu, self.attn_mask = self._make_attention_mask(seq_lens=seq_lens_cpu,
position=positions_cpu, position=positions_cpu,
attn_state=attn_state) attn_state=attn_state)
self.fia_attn_mask = self._make_fia_attention_mask()
self.attn_state = attn_state # type: ignore self.attn_state = attn_state # type: ignore
self.with_prefill = with_prefill self.with_prefill = with_prefill
@@ -1899,6 +1906,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
num_computed_tokens_cpu=num_computed_tokens_cpu, num_computed_tokens_cpu=num_computed_tokens_cpu,
positions=self.positions, positions=self.positions,
attn_mask=self.attn_mask, attn_mask=self.attn_mask,
fia_attn_mask=self.fia_attn_mask,
spec_attn_mask=self.spec_attn_mask, spec_attn_mask=self.spec_attn_mask,
attn_state=self.attn_state, attn_state=self.attn_state,
is_only_prefill=bool(np.all(num_valid_tokens != 1)), is_only_prefill=bool(np.all(num_valid_tokens != 1)),
@@ -2756,13 +2764,18 @@ class NPUModelRunner(LoRAModelRunnerMixin):
cu_num_tokens, arange = self._get_cumsum_and_arange( cu_num_tokens, arange = self._get_cumsum_and_arange(
num_scheduled_tokens) num_scheduled_tokens)
query_start_loc_tensor = torch.Tensor(cu_num_tokens).to(
self.device).to(torch.int32) self.query_start_loc[1:num_reqs + 1] = torch.Tensor(cu_num_tokens)
self.query_start_loc[1:num_reqs + 1] = query_start_loc_tensor
self.query_start_loc_cpu[1:num_reqs + self.query_start_loc_cpu[1:num_reqs +
1] = torch.Tensor(cu_num_tokens) 1] = torch.Tensor(cu_num_tokens)
self.query_lens = torch.from_numpy(num_scheduled_tokens) self.query_lens = torch.from_numpy(num_scheduled_tokens)
assigned_mask_dim = 2048
self.fia_attn_mask = torch.triu(torch.ones(assigned_mask_dim,
assigned_mask_dim),
diagonal=1).to(torch.int8).to(
self.device)
num_computed_tokens_cpu = ( num_computed_tokens_cpu = (
self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs]) self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs])
@@ -2805,6 +2818,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
num_computed_tokens_cpu=num_computed_tokens_cpu, num_computed_tokens_cpu=num_computed_tokens_cpu,
positions=self.positions, positions=self.positions,
attn_mask=self.attn_mask, attn_mask=self.attn_mask,
fia_attn_mask=self.fia_attn_mask,
spec_attn_mask=self.spec_attn_mask, spec_attn_mask=self.spec_attn_mask,
attn_state=self.attn_state, attn_state=self.attn_state,
max_query_len=max_query_len, max_query_len=max_query_len,
@@ -3978,10 +3992,12 @@ class NPUModelRunner(LoRAModelRunnerMixin):
graph_support = None graph_support = None
if hasattr(builder, 'aclgraph_support'): if hasattr(builder, 'aclgraph_support'):
graph_support = builder.aclgraph_support.value graph_support = builder.aclgraph_support.value
builder_aclgraph = builder.aclgraph_support
else: else:
graph_support = builder.cudagraph_support.value graph_support = builder.cudagraph_support.value
builder_aclgraph = builder.cudagraph_support
if graph_support < min_ag_support.value: if graph_support < min_ag_support.value:
min_ag_support = builder.aclgraph_support min_ag_support = builder_aclgraph
min_ag_builder_name = builder.__class__.__name__ min_ag_builder_name = builder.__class__.__name__
# This is an imitation of compilation_config.splitting_ops_contain_attention() # This is an imitation of compilation_config.splitting_ops_contain_attention()