support FULL graph mode for GQA (#3970)
### What this PR does / why we need it?
The current library only supports the FullDecodeOnly graph mode, which
enables full graph execution during the decode. This PR extends support
to allow full graph execution in both the prefill and decode, referred
to as FULL graph mode.
- vLLM version: v0.11.0
- vLLM main:
2918c1b49c
Signed-off-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com>
Co-authored-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com>
This commit is contained in:
1
.github/workflows/_e2e_test.yaml
vendored
1
.github/workflows/_e2e_test.yaml
vendored
@@ -180,6 +180,7 @@ jobs:
|
|||||||
if: ${{ inputs.type == 'full' }}
|
if: ${{ inputs.type == 'full' }}
|
||||||
run: |
|
run: |
|
||||||
pytest -sv tests/e2e/multicard/test_torchair_graph_mode.py
|
pytest -sv tests/e2e/multicard/test_torchair_graph_mode.py
|
||||||
|
pytest -sv tests/e2e/multicard/test_full_graph_mode.py
|
||||||
pytest -sv tests/e2e/multicard/test_data_parallel.py
|
pytest -sv tests/e2e/multicard/test_data_parallel.py
|
||||||
pytest -sv tests/e2e/multicard/test_expert_parallel.py
|
pytest -sv tests/e2e/multicard/test_expert_parallel.py
|
||||||
pytest -sv tests/e2e/multicard/test_external_launcher.py
|
pytest -sv tests/e2e/multicard/test_external_launcher.py
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ from tests.e2e.conftest import VllmRunner
|
|||||||
from tests.e2e.model_utils import check_outputs_equal
|
from tests.e2e.model_utils import check_outputs_equal
|
||||||
|
|
||||||
|
|
||||||
def test_models_distributed_Qwen3_MOE_TP2_WITH_FULLGRAPH():
|
def test_models_distributed_Qwen3_MOE_TP2_WITH_FULL_DECODE_ONLY():
|
||||||
if 'HCCL_OP_EXPANSION_MODE' in os.environ:
|
if 'HCCL_OP_EXPANSION_MODE' in os.environ:
|
||||||
del os.environ['HCCL_OP_EXPANSION_MODE']
|
del os.environ['HCCL_OP_EXPANSION_MODE']
|
||||||
prompts = [
|
prompts = [
|
||||||
@@ -42,15 +42,64 @@ def test_models_distributed_Qwen3_MOE_TP2_WITH_FULLGRAPH():
|
|||||||
max_model_len=1024,
|
max_model_len=1024,
|
||||||
tensor_parallel_size=2,
|
tensor_parallel_size=2,
|
||||||
enforce_eager=False,
|
enforce_eager=False,
|
||||||
compilation_config={"cudagraph_mode":
|
compilation_config={
|
||||||
"FULL_DECODE_ONLY"}) as runner:
|
"cudagraph_mode": "FULL_DECODE_ONLY",
|
||||||
|
"cudagraph_capture_sizes": [4, 8, 24, 48, 60]
|
||||||
|
}) as runner:
|
||||||
vllm_fullgraph_outputs = runner.model.generate(prompts,
|
vllm_fullgraph_outputs = runner.model.generate(prompts,
|
||||||
sampling_params)
|
sampling_params)
|
||||||
|
|
||||||
with VllmRunner(
|
with VllmRunner(
|
||||||
model,
|
model,
|
||||||
max_model_len=1024,
|
max_model_len=1024,
|
||||||
enforce_eager=True,
|
tensor_parallel_size=2,
|
||||||
|
enforce_eager=False,
|
||||||
|
) as runner:
|
||||||
|
vllm_eager_outputs = runner.model.generate(prompts, sampling_params)
|
||||||
|
|
||||||
|
vllm_fullgraph_outputs_list = []
|
||||||
|
for output in vllm_fullgraph_outputs:
|
||||||
|
vllm_fullgraph_outputs_list.append(
|
||||||
|
(output.outputs[0].index, output.outputs[0].text))
|
||||||
|
|
||||||
|
vllm_eager_outputs_list = []
|
||||||
|
for output in vllm_eager_outputs:
|
||||||
|
vllm_eager_outputs_list.append(
|
||||||
|
(output.outputs[0].index, output.outputs[0].text))
|
||||||
|
|
||||||
|
check_outputs_equal(
|
||||||
|
outputs_0_lst=vllm_eager_outputs_list,
|
||||||
|
outputs_1_lst=vllm_fullgraph_outputs_list,
|
||||||
|
name_0="vllm_eager_outputs",
|
||||||
|
name_1="vllm_fullgraph_outputs",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_models_distributed_Qwen3_MOE_TP2_WITH_FULL():
|
||||||
|
if 'HCCL_OP_EXPANSION_MODE' in os.environ:
|
||||||
|
del os.environ['HCCL_OP_EXPANSION_MODE']
|
||||||
|
prompts = [
|
||||||
|
"Hello, my name is", "The president of the United States is",
|
||||||
|
"The capital of France is", "The future of AI is"
|
||||||
|
]
|
||||||
|
model = "Qwen/Qwen3-30B-A3B"
|
||||||
|
sampling_params = SamplingParams(max_tokens=32, temperature=0.0)
|
||||||
|
with VllmRunner(model,
|
||||||
|
max_model_len=1024,
|
||||||
|
tensor_parallel_size=2,
|
||||||
|
enforce_eager=False,
|
||||||
|
compilation_config={
|
||||||
|
"cudagraph_mode": "FULL",
|
||||||
|
"cudagraph_capture_sizes": [4, 8, 24, 48, 60]
|
||||||
|
}) as runner:
|
||||||
|
vllm_fullgraph_outputs = runner.model.generate(prompts,
|
||||||
|
sampling_params)
|
||||||
|
|
||||||
|
with VllmRunner(
|
||||||
|
model,
|
||||||
|
max_model_len=1024,
|
||||||
|
tensor_parallel_size=2,
|
||||||
|
enforce_eager=False,
|
||||||
) as runner:
|
) as runner:
|
||||||
vllm_eager_outputs = runner.model.generate(prompts, sampling_params)
|
vllm_eager_outputs = runner.model.generate(prompts, sampling_params)
|
||||||
|
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ def mtp_correctness(sampling_config: SamplingParams,
|
|||||||
|
|
||||||
graph_mode_str = "PIECEWISE"
|
graph_mode_str = "PIECEWISE"
|
||||||
if graph_mode == CUDAGraphMode.FULL:
|
if graph_mode == CUDAGraphMode.FULL:
|
||||||
graph_mode_str = "FULL"
|
graph_mode_str = "FULL_DECODE_ONLY"
|
||||||
|
|
||||||
with VllmRunner(
|
with VllmRunner(
|
||||||
model_name,
|
model_name,
|
||||||
@@ -63,7 +63,9 @@ def mtp_correctness(sampling_config: SamplingParams,
|
|||||||
enforce_eager=enforce_eager,
|
enforce_eager=enforce_eager,
|
||||||
max_model_len=2000,
|
max_model_len=2000,
|
||||||
compilation_config=CompilationConfig(
|
compilation_config=CompilationConfig(
|
||||||
cudagraph_mode=graph_mode_str),
|
cudagraph_mode=graph_mode_str,
|
||||||
|
cudagraph_capture_sizes=[12],
|
||||||
|
),
|
||||||
additional_config={"ascend_scheduler_config": {
|
additional_config={"ascend_scheduler_config": {
|
||||||
"enabled": False
|
"enabled": False
|
||||||
}}) as spec_llm:
|
}}) as spec_llm:
|
||||||
|
|||||||
@@ -286,10 +286,12 @@ class TestAscendAttentionBackendImpl(TestBase):
|
|||||||
|
|
||||||
assert output.shape == (10, 8 * 64)
|
assert output.shape == (10, 8 * 64)
|
||||||
|
|
||||||
|
@patch('vllm_ascend.attention.attention_v1.get_forward_context')
|
||||||
@patch('torch_npu._npu_reshape_and_cache')
|
@patch('torch_npu._npu_reshape_and_cache')
|
||||||
@patch('torch_npu._npu_flash_attention')
|
@patch('torch_npu._npu_flash_attention')
|
||||||
def test_forward_prefill_no_cache(self, mock_flash_attention,
|
def test_forward_prefill_no_cache(self, mock_flash_attention,
|
||||||
mock_reshape_cache):
|
mock_reshape_cache,
|
||||||
|
mock_get_forward_context):
|
||||||
"""Test forward pass in PrefillNoCache state"""
|
"""Test forward pass in PrefillNoCache state"""
|
||||||
query = torch.randn(10, 8 * 64)
|
query = torch.randn(10, 8 * 64)
|
||||||
key = torch.randn(10, 8 * 64)
|
key = torch.randn(10, 8 * 64)
|
||||||
@@ -297,6 +299,8 @@ class TestAscendAttentionBackendImpl(TestBase):
|
|||||||
kv_cache = torch.empty(2, 5, 128, 8, 64)
|
kv_cache = torch.empty(2, 5, 128, 8, 64)
|
||||||
output = torch.empty_like(query)
|
output = torch.empty_like(query)
|
||||||
|
|
||||||
|
mock_get_forward_context.return_value = MagicMock(capturing=False)
|
||||||
|
|
||||||
metadata = self.attn_metadata
|
metadata = self.attn_metadata
|
||||||
metadata.attn_state = AscendAttentionState.PrefillNoCache
|
metadata.attn_state = AscendAttentionState.PrefillNoCache
|
||||||
metadata.attn_mask = torch.randn(1, 1, 10, 10)
|
metadata.attn_mask = torch.randn(1, 1, 10, 10)
|
||||||
@@ -316,7 +320,8 @@ class TestAscendAttentionBackendImpl(TestBase):
|
|||||||
|
|
||||||
@patch('torch_npu._npu_reshape_and_cache')
|
@patch('torch_npu._npu_reshape_and_cache')
|
||||||
@patch('torch_npu.npu_fused_infer_attention_score')
|
@patch('torch_npu.npu_fused_infer_attention_score')
|
||||||
def test_forward_prefill_cache_hit(self,
|
@patch('vllm_ascend.attention.attention_v1.get_forward_context')
|
||||||
|
def test_forward_prefill_cache_hit(self, mock_get_forward_context,
|
||||||
mock_npu_fused_infer_attention_score,
|
mock_npu_fused_infer_attention_score,
|
||||||
mock_npu_reshape_and_cache):
|
mock_npu_reshape_and_cache):
|
||||||
"""Test forward pass in PrefillCacheHit state"""
|
"""Test forward pass in PrefillCacheHit state"""
|
||||||
@@ -326,8 +331,6 @@ class TestAscendAttentionBackendImpl(TestBase):
|
|||||||
kv_cache = torch.empty(2, 5, 128, 8, 64)
|
kv_cache = torch.empty(2, 5, 128, 8, 64)
|
||||||
output = torch.empty_like(query)
|
output = torch.empty_like(query)
|
||||||
|
|
||||||
mock_npu_fused_infer_attention_score.return_value = (output, 1)
|
|
||||||
|
|
||||||
metadata = self.attn_metadata
|
metadata = self.attn_metadata
|
||||||
metadata.attn_state = AscendAttentionState.PrefillCacheHit
|
metadata.attn_state = AscendAttentionState.PrefillCacheHit
|
||||||
metadata.attn_mask = torch.randn(1, 1, 10, 10)
|
metadata.attn_mask = torch.randn(1, 1, 10, 10)
|
||||||
@@ -340,18 +343,23 @@ class TestAscendAttentionBackendImpl(TestBase):
|
|||||||
metadata.num_prefills = 10
|
metadata.num_prefills = 10
|
||||||
layer = self.layer_no_quant
|
layer = self.layer_no_quant
|
||||||
|
|
||||||
|
mock_get_forward_context.return_value = MagicMock(capturing=False)
|
||||||
|
mock_npu_fused_infer_attention_score.return_value = (output,
|
||||||
|
torch.ones(
|
||||||
|
10, 8, 64))
|
||||||
|
|
||||||
output = self.impl.forward(layer, query, key, value, kv_cache,
|
output = self.impl.forward(layer, query, key, value, kv_cache,
|
||||||
metadata, output)
|
metadata, output)
|
||||||
|
|
||||||
mock_npu_fused_infer_attention_score.assert_called_once()
|
mock_npu_fused_infer_attention_score.assert_called_once()
|
||||||
assert output.shape == (10, 8 * 64)
|
assert output.shape == (10, 8 * 64)
|
||||||
|
|
||||||
@patch('vllm_ascend.attention.attention_v1.get_forward_context')
|
|
||||||
@patch('torch_npu._npu_reshape_and_cache')
|
|
||||||
@patch('torch_npu._npu_paged_attention')
|
@patch('torch_npu._npu_paged_attention')
|
||||||
def test_forward_decode_only(self, mock_paged_attention,
|
@patch('torch_npu._npu_reshape_and_cache')
|
||||||
|
@patch('vllm_ascend.attention.attention_v1.get_forward_context')
|
||||||
|
def test_forward_decode_only(self, mock_get_forward_context,
|
||||||
mock_npu_reshape_and_cache,
|
mock_npu_reshape_and_cache,
|
||||||
mock_get_forward_context):
|
mock_paged_attention):
|
||||||
"""Test forward pass in DecodeOnly state"""
|
"""Test forward pass in DecodeOnly state"""
|
||||||
query = torch.randn(10, 8 * 64)
|
query = torch.randn(10, 8 * 64)
|
||||||
key = torch.randn(10, 8 * 64)
|
key = torch.randn(10, 8 * 64)
|
||||||
@@ -378,115 +386,11 @@ class TestAscendAttentionBackendImpl(TestBase):
|
|||||||
assert output.shape == (10, 8 * 64)
|
assert output.shape == (10, 8 * 64)
|
||||||
|
|
||||||
@patch('vllm_ascend.attention.attention_v1.get_forward_context')
|
@patch('vllm_ascend.attention.attention_v1.get_forward_context')
|
||||||
@patch('vllm_ascend.attention.attention_v1.get_graph_params')
|
|
||||||
@patch('torch_npu._npu_reshape_and_cache')
|
|
||||||
@patch('torch_npu._npu_paged_attention')
|
|
||||||
@patch('torch.npu.graph_task_group_end')
|
|
||||||
@patch('torch.npu.graph_task_group_begin')
|
|
||||||
@patch('torch.npu.ExternalEvent')
|
|
||||||
@patch('torch_npu.npu.current_stream')
|
|
||||||
@patch('vllm_ascend.attention.attention_v1.weak_ref_tensors')
|
|
||||||
def test_paged_attention_with_existing_workspace(
|
|
||||||
self,
|
|
||||||
mock_get_forward_context,
|
|
||||||
mock_get_graph_params,
|
|
||||||
mock_npu_reshape_and_cache,
|
|
||||||
mock_paged_attention,
|
|
||||||
mock_graph_begin,
|
|
||||||
mock_graph_end,
|
|
||||||
mock_external_event_class,
|
|
||||||
mock_current_stream,
|
|
||||||
mock_weak_ref_tensors,
|
|
||||||
):
|
|
||||||
graph_params = MagicMock()
|
|
||||||
attn_metadata = MagicMock()
|
|
||||||
num_tokens = 10
|
|
||||||
|
|
||||||
graph_params.workspaces = {num_tokens: 10}
|
|
||||||
graph_params.events = {num_tokens: []}
|
|
||||||
graph_params.attn_params = {num_tokens: []}
|
|
||||||
graph_params.handles = {num_tokens: []}
|
|
||||||
|
|
||||||
query = torch.randn(2, 5, 8) # [batch_size, seq_len, hidden_size]
|
|
||||||
key_cache = MagicMock()
|
|
||||||
value_cache = MagicMock()
|
|
||||||
num_kv_heads = 4
|
|
||||||
num_heads = 8
|
|
||||||
scale = 0.1
|
|
||||||
output = torch.randn(2, 5, 8)
|
|
||||||
|
|
||||||
self_obj = MagicMock()
|
|
||||||
self_obj.key_cache = key_cache
|
|
||||||
self_obj.value_cache = value_cache
|
|
||||||
self_obj.num_kv_heads = num_kv_heads
|
|
||||||
self_obj.num_heads = num_heads
|
|
||||||
self_obj.scale = scale
|
|
||||||
|
|
||||||
mock_stream = MagicMock()
|
|
||||||
mock_current_stream.return_value = mock_stream
|
|
||||||
mock_event_instance = MagicMock()
|
|
||||||
mock_external_event_class.return_value = mock_event_instance
|
|
||||||
|
|
||||||
mock_handle = MagicMock()
|
|
||||||
mock_graph_end.return_value = mock_handle
|
|
||||||
|
|
||||||
workspace = graph_params.workspaces.get(num_tokens)
|
|
||||||
self.assertEqual(workspace, 10)
|
|
||||||
|
|
||||||
weak_ref_tensors = MagicMock(side_effect=lambda x: x)
|
|
||||||
|
|
||||||
# 2. Handle graph capturing mode
|
|
||||||
stream = mock_current_stream()
|
|
||||||
event = mock_external_event_class()
|
|
||||||
event.wait(stream)
|
|
||||||
event.reset(stream)
|
|
||||||
graph_params.events[num_tokens].append(event)
|
|
||||||
graph_params.attn_params[num_tokens].append((
|
|
||||||
weak_ref_tensors(query),
|
|
||||||
weak_ref_tensors(self_obj.key_cache),
|
|
||||||
weak_ref_tensors(self_obj.value_cache),
|
|
||||||
self_obj.num_kv_heads,
|
|
||||||
self_obj.num_heads,
|
|
||||||
self_obj.scale,
|
|
||||||
weak_ref_tensors(attn_metadata.block_tables),
|
|
||||||
attn_metadata.seq_lens,
|
|
||||||
output,
|
|
||||||
))
|
|
||||||
|
|
||||||
mock_event_instance.wait.assert_called_once_with(mock_stream)
|
|
||||||
mock_event_instance.reset.assert_called_once_with(mock_stream)
|
|
||||||
self.assertEqual(len(graph_params.events[num_tokens]), 1)
|
|
||||||
self.assertEqual(len(graph_params.attn_params[num_tokens]), 1)
|
|
||||||
|
|
||||||
query = torch.randn(10, 8 * 64)
|
|
||||||
key = torch.randn(10, 8 * 64)
|
|
||||||
value = torch.randn(10, 8 * 64)
|
|
||||||
kv_cache = torch.empty(2, 5, 128, 8, 64)
|
|
||||||
output = torch.empty_like(query)
|
|
||||||
|
|
||||||
metadata = self.attn_metadata
|
|
||||||
metadata.attn_state = AscendAttentionState.DecodeOnly
|
|
||||||
metadata.seq_lens = torch.tensor([10])
|
|
||||||
metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
|
|
||||||
metadata.num_actual_tokens = 10
|
|
||||||
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
|
|
||||||
metadata.num_decodes = 0
|
|
||||||
metadata.num_prefills = 10
|
|
||||||
layer = self.layer_no_quant
|
|
||||||
|
|
||||||
mock_get_forward_context.return_value = MagicMock(capturing=True)
|
|
||||||
mock_get_graph_params.return_value = graph_params
|
|
||||||
|
|
||||||
output = self.impl.forward(layer, query, key, value, kv_cache,
|
|
||||||
metadata, output)
|
|
||||||
|
|
||||||
mock_paged_attention.assert_called_once()
|
|
||||||
self.assertEqual(len(graph_params.handles[num_tokens]), 0)
|
|
||||||
|
|
||||||
@patch('torch_npu._npu_reshape_and_cache')
|
|
||||||
@patch('torch_npu.npu_fused_infer_attention_score')
|
@patch('torch_npu.npu_fused_infer_attention_score')
|
||||||
def test_forward_decode_only_swa(self, mock_fused_infer_attention_score,
|
@patch('torch_npu._npu_reshape_and_cache')
|
||||||
mock_npu_reshape_and_cache):
|
def test_forward_decode_only_swa(self, mock_npu_reshape_and_cache,
|
||||||
|
mock_fused_infer_attention_score,
|
||||||
|
mock_get_forward_context):
|
||||||
"""Test forward pass in DecodeOnly state"""
|
"""Test forward pass in DecodeOnly state"""
|
||||||
query = torch.randn(10, 8 * 64)
|
query = torch.randn(10, 8 * 64)
|
||||||
key = torch.randn(10, 8 * 64)
|
key = torch.randn(10, 8 * 64)
|
||||||
@@ -494,6 +398,8 @@ class TestAscendAttentionBackendImpl(TestBase):
|
|||||||
kv_cache = torch.empty(2, 5, 128, 8, 64)
|
kv_cache = torch.empty(2, 5, 128, 8, 64)
|
||||||
output = torch.empty(10, 8, 64)
|
output = torch.empty(10, 8, 64)
|
||||||
|
|
||||||
|
mock_get_forward_context.return_value = MagicMock(capturing=False)
|
||||||
|
|
||||||
metadata = self.attn_metadata
|
metadata = self.attn_metadata
|
||||||
metadata.attn_state = AscendAttentionState.DecodeOnly
|
metadata.attn_state = AscendAttentionState.DecodeOnly
|
||||||
metadata.seq_lens = torch.tensor([10] * 10)
|
metadata.seq_lens = torch.tensor([10] * 10)
|
||||||
@@ -512,12 +418,12 @@ class TestAscendAttentionBackendImpl(TestBase):
|
|||||||
assert output.shape == (10, 8, 64)
|
assert output.shape == (10, 8, 64)
|
||||||
|
|
||||||
@patch('vllm_ascend.attention.attention_v1.get_forward_context')
|
@patch('vllm_ascend.attention.attention_v1.get_forward_context')
|
||||||
@patch('torch_npu._npu_reshape_and_cache')
|
|
||||||
@patch('torch_npu._npu_paged_attention')
|
@patch('torch_npu._npu_paged_attention')
|
||||||
@patch('torch_npu.npu_fused_infer_attention_score')
|
@patch('torch_npu.npu_fused_infer_attention_score')
|
||||||
|
@patch('torch_npu._npu_reshape_and_cache')
|
||||||
def test_forward_decode_only_swa_seq_len_mismatch(
|
def test_forward_decode_only_swa_seq_len_mismatch(
|
||||||
self, mock_fused_infer_attention_score, mock_paged_attention,
|
self, mock_npu_reshape_and_cache, mock_fused_infer_attention_score,
|
||||||
mock_npu_reshape_and_cache, mock_get_forward_context):
|
mock_paged_attention, mock_get_forward_context):
|
||||||
"""Test forward pass in DecodeOnly state when seq)len_mismatch"""
|
"""Test forward pass in DecodeOnly state when seq)len_mismatch"""
|
||||||
query = torch.randn(10, 8 * 64)
|
query = torch.randn(10, 8 * 64)
|
||||||
key = torch.randn(10, 8 * 64)
|
key = torch.randn(10, 8 * 64)
|
||||||
@@ -535,11 +441,11 @@ class TestAscendAttentionBackendImpl(TestBase):
|
|||||||
metadata.num_decodes = 10
|
metadata.num_decodes = 10
|
||||||
metadata.num_prefills = 0
|
metadata.num_prefills = 0
|
||||||
|
|
||||||
mock_fused_infer_attention_score.return_value = (torch.ones(10, 8,
|
|
||||||
64), 1)
|
|
||||||
|
|
||||||
mock_get_forward_context.return_value = MagicMock(capturing=False)
|
mock_get_forward_context.return_value = MagicMock(capturing=False)
|
||||||
|
|
||||||
|
mock_fused_infer_attention_score.return_value = (torch.ones(10, 8, 64),
|
||||||
|
torch.ones(10, 8, 64))
|
||||||
|
|
||||||
output = self.impl_swa.forward(layer, query, key, value, kv_cache,
|
output = self.impl_swa.forward(layer, query, key, value, kv_cache,
|
||||||
metadata, output)
|
metadata, output)
|
||||||
|
|
||||||
@@ -548,11 +454,13 @@ class TestAscendAttentionBackendImpl(TestBase):
|
|||||||
|
|
||||||
assert output.shape == (10, 8 * 64)
|
assert output.shape == (10, 8 * 64)
|
||||||
|
|
||||||
|
@patch('vllm_ascend.attention.attention_v1.get_forward_context')
|
||||||
@patch('vllm_ascend.attention.attention_v1.is_310p', return_value=False)
|
@patch('vllm_ascend.attention.attention_v1.is_310p', return_value=False)
|
||||||
@patch('torch_npu._npu_reshape_and_cache')
|
@patch('torch_npu._npu_reshape_and_cache')
|
||||||
@patch('vllm_ascend.attention.attention_v1.vanilla_chunked_prefill')
|
@patch('vllm_ascend.attention.attention_v1.vanilla_chunked_prefill')
|
||||||
def test_forward_head_size_192(self, mock_vanilla_prefill,
|
def test_forward_head_size_192(self, mock_vanilla_prefill,
|
||||||
mock_npu_reshape_and_cache, mock_is_310p):
|
mock_npu_reshape_and_cache, mock_is_310p,
|
||||||
|
mock_get_forward_context):
|
||||||
"""Test forward pass when head_size is 192"""
|
"""Test forward pass when head_size is 192"""
|
||||||
|
|
||||||
self.impl.head_size = 192
|
self.impl.head_size = 192
|
||||||
@@ -562,6 +470,8 @@ class TestAscendAttentionBackendImpl(TestBase):
|
|||||||
kv_cache = torch.empty(2, 5, 128, 8, 192)
|
kv_cache = torch.empty(2, 5, 128, 8, 192)
|
||||||
output = torch.empty_like(query)
|
output = torch.empty_like(query)
|
||||||
|
|
||||||
|
mock_get_forward_context.return_value = MagicMock(capturing=False)
|
||||||
|
|
||||||
metadata = self.attn_metadata
|
metadata = self.attn_metadata
|
||||||
metadata.attn_mask = torch.randn(1, 1, 10, 10)
|
metadata.attn_mask = torch.randn(1, 1, 10, 10)
|
||||||
metadata.query_lens = torch.tensor([10])
|
metadata.query_lens = torch.tensor([10])
|
||||||
@@ -580,11 +490,12 @@ class TestAscendAttentionBackendImpl(TestBase):
|
|||||||
mock_vanilla_prefill.assert_called_once()
|
mock_vanilla_prefill.assert_called_once()
|
||||||
assert output.shape == (10, 8 * 192)
|
assert output.shape == (10, 8 * 192)
|
||||||
|
|
||||||
@patch('torch_npu._npu_reshape_and_cache')
|
@patch('vllm_ascend.attention.attention_v1.get_forward_context')
|
||||||
@patch('torch_npu.npu_fused_infer_attention_score')
|
@patch('torch_npu.npu_fused_infer_attention_score')
|
||||||
def test_forward_normal_v1_situation(self,
|
@patch('torch_npu._npu_reshape_and_cache')
|
||||||
|
def test_forward_normal_v1_situation(self, mock_npu_reshape_and_cache,
|
||||||
mock_npu_fused_infer_attention_score,
|
mock_npu_fused_infer_attention_score,
|
||||||
mock_npu_reshape_and_cache):
|
mock_get_forward_context):
|
||||||
"""Test forward pass in normal V1 situation"""
|
"""Test forward pass in normal V1 situation"""
|
||||||
query = torch.randn(10, 8 * 64)
|
query = torch.randn(10, 8 * 64)
|
||||||
key = torch.randn(10, 8 * 64)
|
key = torch.randn(10, 8 * 64)
|
||||||
@@ -592,8 +503,6 @@ class TestAscendAttentionBackendImpl(TestBase):
|
|||||||
kv_cache = torch.empty(2, 5, 128, 8, 64)
|
kv_cache = torch.empty(2, 5, 128, 8, 64)
|
||||||
output = torch.empty_like(query)
|
output = torch.empty_like(query)
|
||||||
|
|
||||||
mock_npu_fused_infer_attention_score.return_value = (output, 1)
|
|
||||||
|
|
||||||
metadata = self.attn_metadata
|
metadata = self.attn_metadata
|
||||||
metadata.attn_mask = torch.randn(1, 1, 10, 10)
|
metadata.attn_mask = torch.randn(1, 1, 10, 10)
|
||||||
metadata.query_lens = torch.tensor([10])
|
metadata.query_lens = torch.tensor([10])
|
||||||
@@ -604,6 +513,10 @@ class TestAscendAttentionBackendImpl(TestBase):
|
|||||||
metadata.num_decodes = 0
|
metadata.num_decodes = 0
|
||||||
metadata.num_prefills = 10
|
metadata.num_prefills = 10
|
||||||
layer = self.layer_no_quant
|
layer = self.layer_no_quant
|
||||||
|
mock_get_forward_context.return_value = MagicMock(capturing=False)
|
||||||
|
mock_npu_fused_infer_attention_score.return_value = (output,
|
||||||
|
torch.ones(
|
||||||
|
10, 8, 64))
|
||||||
|
|
||||||
output = self.impl.forward(layer, query, key, value, kv_cache,
|
output = self.impl.forward(layer, query, key, value, kv_cache,
|
||||||
metadata, output)
|
metadata, output)
|
||||||
@@ -615,7 +528,8 @@ class TestAscendAttentionBackendImpl(TestBase):
|
|||||||
@patch('torch_npu._npu_reshape_and_cache')
|
@patch('torch_npu._npu_reshape_and_cache')
|
||||||
@patch('torch_npu.npu_fused_infer_attention_score')
|
@patch('torch_npu.npu_fused_infer_attention_score')
|
||||||
@patch('vllm_ascend.attention.attention_v1.is_310p', return_value=True)
|
@patch('vllm_ascend.attention.attention_v1.is_310p', return_value=True)
|
||||||
def test_forward_310p_device(self, mock_is_310p,
|
@patch('vllm_ascend.attention.attention_v1.get_forward_context')
|
||||||
|
def test_forward_310p_device(self, mock_get_forward_context, mock_is_310p,
|
||||||
mock_npu_fused_infer_attention_score,
|
mock_npu_fused_infer_attention_score,
|
||||||
mock_npu_reshape_and_cache,
|
mock_npu_reshape_and_cache,
|
||||||
mock_npu_format_cast):
|
mock_npu_format_cast):
|
||||||
@@ -626,8 +540,6 @@ class TestAscendAttentionBackendImpl(TestBase):
|
|||||||
kv_cache = torch.empty(2, 5, 128, 8, 64)
|
kv_cache = torch.empty(2, 5, 128, 8, 64)
|
||||||
output = torch.empty_like(query)
|
output = torch.empty_like(query)
|
||||||
|
|
||||||
mock_npu_fused_infer_attention_score.return_value = (output, 1)
|
|
||||||
|
|
||||||
metadata = self.attn_metadata
|
metadata = self.attn_metadata
|
||||||
metadata.attn_mask = torch.randn(1, 1, 10, 10)
|
metadata.attn_mask = torch.randn(1, 1, 10, 10)
|
||||||
metadata.query_lens = torch.tensor([10])
|
metadata.query_lens = torch.tensor([10])
|
||||||
@@ -641,6 +553,11 @@ class TestAscendAttentionBackendImpl(TestBase):
|
|||||||
|
|
||||||
mock_npu_format_cast.return_value = metadata.attn_mask
|
mock_npu_format_cast.return_value = metadata.attn_mask
|
||||||
|
|
||||||
|
mock_get_forward_context.return_value = MagicMock(capturing=False)
|
||||||
|
mock_npu_fused_infer_attention_score.return_value = (output,
|
||||||
|
torch.ones(
|
||||||
|
10, 8, 64))
|
||||||
|
|
||||||
output = self.impl.forward(layer, query, key, value, kv_cache,
|
output = self.impl.forward(layer, query, key, value, kv_cache,
|
||||||
metadata, output)
|
metadata, output)
|
||||||
|
|
||||||
|
|||||||
@@ -195,6 +195,7 @@ class AscendMetadataForDecode:
|
|||||||
class AscendMetadata:
|
class AscendMetadata:
|
||||||
# **************************** Basic Properties ************************** #
|
# **************************** Basic Properties ************************** #
|
||||||
attn_mask: Optional[torch.Tensor] = None
|
attn_mask: Optional[torch.Tensor] = None
|
||||||
|
fia_attn_mask: Optional[torch.Tensor] = None
|
||||||
# Current state of this attention run.
|
# Current state of this attention run.
|
||||||
attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
|
attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
|
||||||
|
|
||||||
@@ -215,6 +216,7 @@ class AscendMetadata:
|
|||||||
seq_lens: torch.Tensor = None
|
seq_lens: torch.Tensor = None
|
||||||
seq_lens_list: List[int] = None # type: ignore
|
seq_lens_list: List[int] = None # type: ignore
|
||||||
actual_seq_lengths_q: List[int] = None # type: ignore
|
actual_seq_lengths_q: List[int] = None # type: ignore
|
||||||
|
query_start_loc_list: List[int] = None # type: ignore
|
||||||
|
|
||||||
query_start_loc: torch.Tensor = None
|
query_start_loc: torch.Tensor = None
|
||||||
query_lens: torch.Tensor = None
|
query_lens: torch.Tensor = None
|
||||||
@@ -241,7 +243,8 @@ class AscendMetadata:
|
|||||||
class AscendAttentionMetadataBuilder:
|
class AscendAttentionMetadataBuilder:
|
||||||
# Does this backend/builder support ACL Graphs for attention (default: no).
|
# Does this backend/builder support ACL Graphs for attention (default: no).
|
||||||
aclgraph_support: ClassVar[AttentionCGSupport] = \
|
aclgraph_support: ClassVar[AttentionCGSupport] = \
|
||||||
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
|
AttentionCGSupport.ALWAYS
|
||||||
|
# AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
|
||||||
# Does this backend/builder reorder the batch?
|
# Does this backend/builder reorder the batch?
|
||||||
# If not, set this to None. Otherwise set it to the query
|
# If not, set this to None. Otherwise set it to the query
|
||||||
# length that will be pulled into the front of the batch.
|
# length that will be pulled into the front of the batch.
|
||||||
@@ -321,6 +324,7 @@ class AscendAttentionMetadataBuilder:
|
|||||||
num_actual_tokens_pcp_padded]
|
num_actual_tokens_pcp_padded]
|
||||||
# slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens]
|
# slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens]
|
||||||
attn_mask = common_attn_metadata.attn_mask
|
attn_mask = common_attn_metadata.attn_mask
|
||||||
|
fia_attn_mask = common_attn_metadata.fia_attn_mask
|
||||||
attn_state = common_attn_metadata.attn_state
|
attn_state = common_attn_metadata.attn_state
|
||||||
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:
|
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:
|
||||||
num_reqs
|
num_reqs
|
||||||
@@ -471,6 +475,7 @@ class AscendAttentionMetadataBuilder:
|
|||||||
num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded,
|
num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded,
|
||||||
block_tables=block_table,
|
block_tables=block_table,
|
||||||
query_start_loc=query_start_loc,
|
query_start_loc=query_start_loc,
|
||||||
|
query_start_loc_list=query_start_loc_cpu[1:].tolist(),
|
||||||
query_lens=query_lens,
|
query_lens=query_lens,
|
||||||
seq_lens=seq_lens,
|
seq_lens=seq_lens,
|
||||||
seq_lens_list=seq_lens.tolist(),
|
seq_lens_list=seq_lens.tolist(),
|
||||||
@@ -478,6 +483,7 @@ class AscendAttentionMetadataBuilder:
|
|||||||
actual_seq_lengths_q=query_start_loc_cpu[1:].tolist(),
|
actual_seq_lengths_q=query_start_loc_cpu[1:].tolist(),
|
||||||
slot_mapping=slot_mapping,
|
slot_mapping=slot_mapping,
|
||||||
attn_mask=attn_mask,
|
attn_mask=attn_mask,
|
||||||
|
fia_attn_mask=fia_attn_mask,
|
||||||
attn_state=attn_state,
|
attn_state=attn_state,
|
||||||
num_prefills=num_prefills,
|
num_prefills=num_prefills,
|
||||||
num_decodes=num_decodes,
|
num_decodes=num_decodes,
|
||||||
@@ -565,6 +571,113 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
self.dcp_group = get_dcp_group(
|
self.dcp_group = get_dcp_group(
|
||||||
).device_group if self.dcp_size > 1 else None
|
).device_group if self.dcp_size > 1 else None
|
||||||
|
|
||||||
|
def full_graph_attention(self,
|
||||||
|
query: torch.Tensor,
|
||||||
|
key: torch.Tensor,
|
||||||
|
value: torch.Tensor,
|
||||||
|
attn_metadata: AscendMetadata,
|
||||||
|
output: torch.Tensor,
|
||||||
|
num_tokens=0):
|
||||||
|
if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
|
||||||
|
block_size = 128
|
||||||
|
block_table = None
|
||||||
|
actual_seq_lengths_kv = attn_metadata.query_start_loc_list
|
||||||
|
elif attn_metadata.attn_state == \
|
||||||
|
AscendAttentionState.PrefillCacheHit:
|
||||||
|
batch_size = attn_metadata.query_lens.shape[0]
|
||||||
|
block_table = attn_metadata.block_tables[:batch_size, :]
|
||||||
|
num_block, block_size, _, _ = self.key_cache.shape # type: ignore
|
||||||
|
key = self.key_cache.view( # type: ignore
|
||||||
|
num_block, block_size, -1)
|
||||||
|
value = self.value_cache.view( # type: ignore
|
||||||
|
num_block, block_size, -1)
|
||||||
|
actual_seq_lengths_kv = attn_metadata.seq_lens_list
|
||||||
|
elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
|
||||||
|
num_block, block_size, _, _ = self.key_cache.shape # type: ignore
|
||||||
|
key = self.key_cache.view( # type: ignore
|
||||||
|
num_block, block_size, -1)
|
||||||
|
value = self.value_cache.view( # type: ignore
|
||||||
|
num_block, block_size, -1)
|
||||||
|
block_table = attn_metadata.block_tables
|
||||||
|
actual_seq_lengths_kv = attn_metadata.seq_lens_list
|
||||||
|
# Normal V1 situation.
|
||||||
|
else:
|
||||||
|
num_block, block_size, _, _ = self.key_cache.shape # type: ignore
|
||||||
|
key = self.key_cache.view( # type: ignore
|
||||||
|
num_block, block_size, -1)
|
||||||
|
value = self.value_cache.view( # type: ignore
|
||||||
|
num_block, block_size, -1)
|
||||||
|
block_table = attn_metadata.block_tables
|
||||||
|
actual_seq_lengths_kv = attn_metadata.seq_lens_list
|
||||||
|
|
||||||
|
num_tokens = attn_metadata.query_start_loc_list[-1]
|
||||||
|
query = query[:num_tokens]
|
||||||
|
graph_params = get_graph_params()
|
||||||
|
query_start_loc = attn_metadata.query_start_loc_list
|
||||||
|
# Prepare tensors for attention output
|
||||||
|
# TODO: Refactor this to step-level instead of layer-level
|
||||||
|
|
||||||
|
# Get workspace from cache or calculate it if not present.
|
||||||
|
workspace = graph_params.workspaces.get(num_tokens)
|
||||||
|
softmax_lse = torch.empty(1, dtype=query.dtype, device=query.device)
|
||||||
|
if workspace is None:
|
||||||
|
workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
|
||||||
|
query=query,
|
||||||
|
key=key,
|
||||||
|
value=value,
|
||||||
|
atten_mask=attn_metadata.fia_attn_mask,
|
||||||
|
block_table=block_table,
|
||||||
|
input_layout="TND",
|
||||||
|
block_size=block_size,
|
||||||
|
actual_seq_lengths=query_start_loc,
|
||||||
|
actual_seq_lengths_kv=actual_seq_lengths_kv,
|
||||||
|
num_key_value_heads=self.num_kv_heads,
|
||||||
|
num_heads=self.num_heads,
|
||||||
|
sparse_mode=3,
|
||||||
|
scale=self.scale,
|
||||||
|
)
|
||||||
|
update_graph_params_workspaces(num_tokens, workspace)
|
||||||
|
|
||||||
|
# Handle graph capturing mode
|
||||||
|
stream = torch_npu.npu.current_stream()
|
||||||
|
|
||||||
|
event = torch.npu.ExternalEvent()
|
||||||
|
event.wait(stream)
|
||||||
|
event.reset(stream)
|
||||||
|
graph_params.events[num_tokens].append(event)
|
||||||
|
graph_params.attn_params[num_tokens].append(
|
||||||
|
(weak_ref_tensors(query), weak_ref_tensors(key),
|
||||||
|
weak_ref_tensors(value), weak_ref_tensors(block_table),
|
||||||
|
weak_ref_tensors(attn_metadata.fia_attn_mask), block_size,
|
||||||
|
actual_seq_lengths_kv, query_start_loc, self.num_kv_heads,
|
||||||
|
self.num_heads, self.scale, weak_ref_tensors(output),
|
||||||
|
weak_ref_tensors(softmax_lse)))
|
||||||
|
|
||||||
|
torch.npu.graph_task_group_begin(stream)
|
||||||
|
torch_npu.npu_fused_infer_attention_score.out(
|
||||||
|
query=query,
|
||||||
|
key=key,
|
||||||
|
value=value,
|
||||||
|
atten_mask=attn_metadata.fia_attn_mask,
|
||||||
|
block_table=block_table,
|
||||||
|
input_layout="TND",
|
||||||
|
block_size=block_size,
|
||||||
|
actual_seq_lengths=query_start_loc,
|
||||||
|
actual_seq_lengths_kv=actual_seq_lengths_kv,
|
||||||
|
num_key_value_heads=self.num_kv_heads,
|
||||||
|
num_heads=self.num_heads,
|
||||||
|
scale=self.scale,
|
||||||
|
sparse_mode=3,
|
||||||
|
workspace=workspace,
|
||||||
|
out=[output, softmax_lse],
|
||||||
|
)
|
||||||
|
|
||||||
|
output = output.view(num_tokens, self.num_heads, self.head_size)
|
||||||
|
|
||||||
|
handle = torch.npu.graph_task_group_end(stream)
|
||||||
|
graph_params.handles[num_tokens].append(handle)
|
||||||
|
return output, num_tokens
|
||||||
|
|
||||||
def _forward_prefill_no_cache(
|
def _forward_prefill_no_cache(
|
||||||
self,
|
self,
|
||||||
query: torch.Tensor,
|
query: torch.Tensor,
|
||||||
@@ -691,60 +804,6 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
actual_seq_lengths_kv=attn_metadata.seq_lens)
|
actual_seq_lengths_kv=attn_metadata.seq_lens)
|
||||||
|
|
||||||
output = output.view(batch_size, self.num_heads, self.head_size)
|
output = output.view(batch_size, self.num_heads, self.head_size)
|
||||||
else:
|
|
||||||
graph_params = get_graph_params()
|
|
||||||
forward_context: ForwardContext = get_forward_context()
|
|
||||||
num_tokens = query.shape[0]
|
|
||||||
if forward_context.capturing:
|
|
||||||
# Get workspace from cache or calculate it if not present.
|
|
||||||
workspace = graph_params.workspaces.get(num_tokens)
|
|
||||||
if workspace is None:
|
|
||||||
workspace = torch_npu._npu_paged_attention_get_workspace(
|
|
||||||
query=query,
|
|
||||||
key_cache=self.key_cache,
|
|
||||||
value_cache=self.value_cache,
|
|
||||||
num_kv_heads=self.num_kv_heads,
|
|
||||||
num_heads=self.num_heads,
|
|
||||||
scale_value=self.scale,
|
|
||||||
block_table=attn_metadata.block_tables,
|
|
||||||
context_lens=attn_metadata.seq_lens,
|
|
||||||
out=output)
|
|
||||||
update_graph_params_workspaces(num_tokens,
|
|
||||||
weak_ref_tensors(workspace))
|
|
||||||
|
|
||||||
# Handle graph capturing mode
|
|
||||||
stream = torch_npu.npu.current_stream()
|
|
||||||
|
|
||||||
event = torch.npu.ExternalEvent()
|
|
||||||
event.wait(stream)
|
|
||||||
event.reset(stream)
|
|
||||||
graph_params.events[num_tokens].append(event)
|
|
||||||
graph_params.attn_params[num_tokens].append((
|
|
||||||
weak_ref_tensors(query),
|
|
||||||
weak_ref_tensors(self.key_cache),
|
|
||||||
weak_ref_tensors(self.value_cache),
|
|
||||||
self.num_kv_heads,
|
|
||||||
self.num_heads,
|
|
||||||
self.scale,
|
|
||||||
attn_metadata.block_tables,
|
|
||||||
attn_metadata.seq_lens,
|
|
||||||
weak_ref_tensors(output),
|
|
||||||
))
|
|
||||||
|
|
||||||
torch.npu.graph_task_group_begin(stream)
|
|
||||||
torch_npu._npu_paged_attention(
|
|
||||||
query=query,
|
|
||||||
key_cache=self.key_cache,
|
|
||||||
value_cache=self.value_cache,
|
|
||||||
num_kv_heads=self.num_kv_heads,
|
|
||||||
num_heads=self.num_heads,
|
|
||||||
scale_value=self.scale,
|
|
||||||
block_table=attn_metadata.block_tables,
|
|
||||||
context_lens=attn_metadata.seq_lens,
|
|
||||||
out=output,
|
|
||||||
workspace=workspace)
|
|
||||||
handle = torch.npu.graph_task_group_end(stream)
|
|
||||||
graph_params.handles[num_tokens].append(handle)
|
|
||||||
else:
|
else:
|
||||||
torch_npu._npu_paged_attention(
|
torch_npu._npu_paged_attention(
|
||||||
query=query,
|
query=query,
|
||||||
@@ -819,7 +878,6 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
scale=self.scale,
|
scale=self.scale,
|
||||||
sparse_mode=3,
|
sparse_mode=3,
|
||||||
)
|
)
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def _attention_with_nomask_and_mask(self, q: torch.Tensor,
|
def _attention_with_nomask_and_mask(self, q: torch.Tensor,
|
||||||
@@ -1481,6 +1539,8 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
num_decode_tokens:attn_metadata.
|
num_decode_tokens:attn_metadata.
|
||||||
num_actual_tokens_pcp_padded])
|
num_actual_tokens_pcp_padded])
|
||||||
|
|
||||||
|
forward_context: ForwardContext = get_forward_context()
|
||||||
|
if not forward_context.capturing:
|
||||||
if self.pcp_size * self.dcp_size > 1:
|
if self.pcp_size * self.dcp_size > 1:
|
||||||
intermediate_output = self._forward_pcp_dcp(
|
intermediate_output = self._forward_pcp_dcp(
|
||||||
query, key, value, kv_cache, attn_metadata, output)
|
query, key, value, kv_cache, attn_metadata, output)
|
||||||
@@ -1521,7 +1581,9 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
query = query[:num_tokens]
|
query = query[:num_tokens]
|
||||||
intermediate_output = self._forward_v1_style(
|
intermediate_output = self._forward_v1_style(
|
||||||
query, attn_metadata, output)
|
query, attn_metadata, output)
|
||||||
|
else:
|
||||||
|
intermediate_output, num_tokens = self.full_graph_attention(
|
||||||
|
query, key, value, attn_metadata, output)
|
||||||
output[:num_tokens] = intermediate_output[:num_tokens]
|
output[:num_tokens] = intermediate_output[:num_tokens]
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|||||||
@@ -1278,8 +1278,7 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
if workspace is None:
|
if workspace is None:
|
||||||
workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
|
workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
|
||||||
q_nope, k_nope, k_nope, **common_kwargs)
|
q_nope, k_nope, k_nope, **common_kwargs)
|
||||||
update_graph_params_workspaces(num_tokens,
|
update_graph_params_workspaces(num_tokens, workspace)
|
||||||
weak_ref_tensors(workspace))
|
|
||||||
|
|
||||||
attn_output = torch.empty_like(q_nope)
|
attn_output = torch.empty_like(q_nope)
|
||||||
softmax_lse = torch.empty(num_tokens,
|
softmax_lse = torch.empty(num_tokens,
|
||||||
@@ -1779,8 +1778,7 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
q_nope, q_pe, k_nope, k_pe, decode_meta.block_table,
|
q_nope, q_pe, k_nope, k_pe, decode_meta.block_table,
|
||||||
seq_len, num_heads, self.scale, self.num_kv_heads,
|
seq_len, num_heads, self.scale, self.num_kv_heads,
|
||||||
**common_kwargs)
|
**common_kwargs)
|
||||||
update_graph_params_workspaces(num_tokens,
|
update_graph_params_workspaces(num_tokens, workspace)
|
||||||
weak_ref_tensors(workspace))
|
|
||||||
attn_output = torch.empty_like(q_nope)
|
attn_output = torch.empty_like(q_nope)
|
||||||
softmax_lse = torch.empty((num_tokens, num_heads, 1),
|
softmax_lse = torch.empty((num_tokens, num_heads, 1),
|
||||||
dtype=q_nope.dtype,
|
dtype=q_nope.dtype,
|
||||||
|
|||||||
@@ -88,6 +88,8 @@ class AscendCommonAttentionMetadata:
|
|||||||
|
|
||||||
attn_mask: torch.Tensor = None
|
attn_mask: torch.Tensor = None
|
||||||
|
|
||||||
|
fia_attn_mask: torch.Tensor = None
|
||||||
|
|
||||||
spec_attn_mask: torch.Tensor = None
|
spec_attn_mask: torch.Tensor = None
|
||||||
|
|
||||||
attn_state: Any = None
|
attn_state: Any = None
|
||||||
|
|||||||
@@ -203,48 +203,31 @@ def update_attn_params(update_stream, forward_context, runtime_shape):
|
|||||||
graph_params.handles[runtime_shape],
|
graph_params.handles[runtime_shape],
|
||||||
graph_params.events[runtime_shape],
|
graph_params.events[runtime_shape],
|
||||||
):
|
):
|
||||||
(
|
(query, key_cache, value, block_tables, attn_mask, block_size,
|
||||||
query,
|
seq_lens, query_start_loc, num_kv_heads, num_heads, scale,
|
||||||
key_cache,
|
attn_output, softmax_lse) = param
|
||||||
value_cache,
|
|
||||||
num_kv_heads,
|
|
||||||
num_heads,
|
|
||||||
scale,
|
|
||||||
block_table,
|
|
||||||
seq_lens,
|
|
||||||
output,
|
|
||||||
) = param
|
|
||||||
seq_lens = forward_context.attn_metadata[key].seq_lens
|
|
||||||
|
|
||||||
# When using FULL_DECODE_ONLY, there are some rare bugs for FULL_DECODE_ONLY
|
seq_lens = forward_context.attn_metadata[key].seq_lens_list
|
||||||
# mode with GQA. This is triggered by getting workspace for _npu_paged_attention
|
query_start_loc = forward_context.attn_metadata[
|
||||||
# in torch_npu. On some rare cases, _npu_paged_attention with smaller seq_lens
|
key].query_start_loc_list
|
||||||
# might encounter a bigger workspace, while currently we use max_model_len to
|
|
||||||
# calculate max workspace in capturing. So additional get_workspace is added
|
|
||||||
# here to avoid such bugs.
|
|
||||||
# TODO(Angazenn): we will remove this once _npu_paged_attention is fully
|
|
||||||
# replaced by npu_fused_infer_attention_score which does not contain such bugs.
|
|
||||||
workspace = torch_npu._npu_paged_attention_get_workspace(
|
|
||||||
query=query,
|
|
||||||
key_cache=key_cache,
|
|
||||||
value_cache=value_cache,
|
|
||||||
num_kv_heads=num_kv_heads,
|
|
||||||
num_heads=num_heads,
|
|
||||||
scale_value=scale,
|
|
||||||
block_table=block_table,
|
|
||||||
context_lens=seq_lens,
|
|
||||||
out=output)
|
|
||||||
torch.npu.graph_task_update_begin(update_stream, handle)
|
torch.npu.graph_task_update_begin(update_stream, handle)
|
||||||
torch_npu._npu_paged_attention(query=query,
|
torch_npu.npu_fused_infer_attention_score.out(
|
||||||
key_cache=key_cache,
|
query=query,
|
||||||
value_cache=value_cache,
|
key=key_cache,
|
||||||
num_kv_heads=num_kv_heads,
|
value=value,
|
||||||
|
block_table=block_tables,
|
||||||
|
atten_mask=attn_mask,
|
||||||
|
input_layout="TND",
|
||||||
|
block_size=block_size,
|
||||||
|
actual_seq_lengths=query_start_loc,
|
||||||
|
actual_seq_lengths_kv=seq_lens,
|
||||||
|
num_key_value_heads=num_kv_heads,
|
||||||
num_heads=num_heads,
|
num_heads=num_heads,
|
||||||
scale_value=scale,
|
scale=scale,
|
||||||
block_table=block_table,
|
sparse_mode=3,
|
||||||
context_lens=seq_lens,
|
workspace=graph_params.workspaces.get(runtime_shape),
|
||||||
out=output,
|
out=[attn_output, softmax_lse],
|
||||||
workspace=workspace)
|
)
|
||||||
torch.npu.graph_task_update_end(update_stream)
|
torch.npu.graph_task_update_end(update_stream)
|
||||||
|
|
||||||
event.record(update_stream)
|
event.record(update_stream)
|
||||||
@@ -446,10 +429,10 @@ def set_graph_params(aclgraph_capture_sizes: set[int]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def update_graph_params_workspaces(num_tokens: int, workspace: Any):
|
def update_graph_params_workspaces(num_tokens: int, workspace: int):
|
||||||
global _graph_params
|
global _graph_params
|
||||||
if _graph_params is not None:
|
if _graph_params is not None:
|
||||||
_graph_params.workspaces[num_tokens] = workspace
|
_graph_params.workspaces[num_tokens] = weak_ref_tensors(workspace)
|
||||||
|
|
||||||
|
|
||||||
def get_graph_params():
|
def get_graph_params():
|
||||||
|
|||||||
@@ -233,7 +233,8 @@ class NPUPlatform(Platform):
|
|||||||
"vllm.mla_forward"
|
"vllm.mla_forward"
|
||||||
])
|
])
|
||||||
update_aclgraph_sizes(vllm_config)
|
update_aclgraph_sizes(vllm_config)
|
||||||
elif compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
|
elif compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY or\
|
||||||
|
compilation_config.cudagraph_mode == CUDAGraphMode.FULL:
|
||||||
logger.info(
|
logger.info(
|
||||||
"FULL_DECODE_ONLY compilation enabled on NPU. use_inductor not supported - "
|
"FULL_DECODE_ONLY compilation enabled on NPU. use_inductor not supported - "
|
||||||
"using only ACL Graph mode")
|
"using only ACL Graph mode")
|
||||||
@@ -270,7 +271,8 @@ class NPUPlatform(Platform):
|
|||||||
compilation_config.use_inductor = False
|
compilation_config.use_inductor = False
|
||||||
compilation_config.splitting_ops.extend(["vllm::mla_forward"])
|
compilation_config.splitting_ops.extend(["vllm::mla_forward"])
|
||||||
update_aclgraph_sizes(vllm_config)
|
update_aclgraph_sizes(vllm_config)
|
||||||
elif compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
|
elif compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY or\
|
||||||
|
compilation_config.cudagraph_mode == CUDAGraphMode.FULL:
|
||||||
logger.info(
|
logger.info(
|
||||||
"FULL_DECODE_ONLY compilation enabled on NPU. use_inductor not supported - "
|
"FULL_DECODE_ONLY compilation enabled on NPU. use_inductor not supported - "
|
||||||
"using only ACL Graph mode")
|
"using only ACL Graph mode")
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ import math
|
|||||||
import types
|
import types
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@@ -31,7 +32,6 @@ from vllm.distributed.parallel_state import get_dp_group
|
|||||||
from vllm.forward_context import get_forward_context
|
from vllm.forward_context import get_forward_context
|
||||||
from vllm.logger import logger
|
from vllm.logger import logger
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import vllm_ascend.envs as envs_ascend
|
import vllm_ascend.envs as envs_ascend
|
||||||
from vllm_ascend.ascend_config import get_ascend_config
|
from vllm_ascend.ascend_config import get_ascend_config
|
||||||
from vllm_ascend.platform import NPUPlatform
|
from vllm_ascend.platform import NPUPlatform
|
||||||
|
|||||||
@@ -331,6 +331,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
self.attn_groups: list[list[AttentionGroup]] = []
|
self.attn_groups: list[list[AttentionGroup]] = []
|
||||||
self.encoder_cache: Dict[str, torch.Tensor] = {}
|
self.encoder_cache: Dict[str, torch.Tensor] = {}
|
||||||
self.attn_mask = None
|
self.attn_mask = None
|
||||||
|
self.fia_attn_mask = None
|
||||||
self.attn_state = None
|
self.attn_state = None
|
||||||
self.requests: Dict[str, CachedRequestState] = {}
|
self.requests: Dict[str, CachedRequestState] = {}
|
||||||
self.intermediate_tensors: Optional[IntermediateTensors] = None
|
self.intermediate_tensors: Optional[IntermediateTensors] = None
|
||||||
@@ -1030,6 +1031,11 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def _make_fia_attention_mask(self) -> torch.Tensor:
|
||||||
|
if self.attn_mask_builder is None:
|
||||||
|
raise ValueError("Attn mask builder is None")
|
||||||
|
return self.attn_mask_builder.get_splitfuse_attn_mask()
|
||||||
|
|
||||||
def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"):
|
def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"):
|
||||||
mrope_pos_ptr = 0
|
mrope_pos_ptr = 0
|
||||||
for index, req_id in enumerate(self.input_batch.req_ids):
|
for index, req_id in enumerate(self.input_batch.req_ids):
|
||||||
@@ -1667,6 +1673,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
self.attn_mask = self._make_attention_mask(seq_lens=seq_lens_cpu,
|
self.attn_mask = self._make_attention_mask(seq_lens=seq_lens_cpu,
|
||||||
position=positions_cpu,
|
position=positions_cpu,
|
||||||
attn_state=attn_state)
|
attn_state=attn_state)
|
||||||
|
self.fia_attn_mask = self._make_fia_attention_mask()
|
||||||
self.attn_state = attn_state # type: ignore
|
self.attn_state = attn_state # type: ignore
|
||||||
|
|
||||||
self.with_prefill = with_prefill
|
self.with_prefill = with_prefill
|
||||||
@@ -1899,6 +1906,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
num_computed_tokens_cpu=num_computed_tokens_cpu,
|
num_computed_tokens_cpu=num_computed_tokens_cpu,
|
||||||
positions=self.positions,
|
positions=self.positions,
|
||||||
attn_mask=self.attn_mask,
|
attn_mask=self.attn_mask,
|
||||||
|
fia_attn_mask=self.fia_attn_mask,
|
||||||
spec_attn_mask=self.spec_attn_mask,
|
spec_attn_mask=self.spec_attn_mask,
|
||||||
attn_state=self.attn_state,
|
attn_state=self.attn_state,
|
||||||
is_only_prefill=bool(np.all(num_valid_tokens != 1)),
|
is_only_prefill=bool(np.all(num_valid_tokens != 1)),
|
||||||
@@ -2756,13 +2764,18 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
|
|
||||||
cu_num_tokens, arange = self._get_cumsum_and_arange(
|
cu_num_tokens, arange = self._get_cumsum_and_arange(
|
||||||
num_scheduled_tokens)
|
num_scheduled_tokens)
|
||||||
query_start_loc_tensor = torch.Tensor(cu_num_tokens).to(
|
|
||||||
self.device).to(torch.int32)
|
self.query_start_loc[1:num_reqs + 1] = torch.Tensor(cu_num_tokens)
|
||||||
self.query_start_loc[1:num_reqs + 1] = query_start_loc_tensor
|
|
||||||
self.query_start_loc_cpu[1:num_reqs +
|
self.query_start_loc_cpu[1:num_reqs +
|
||||||
1] = torch.Tensor(cu_num_tokens)
|
1] = torch.Tensor(cu_num_tokens)
|
||||||
self.query_lens = torch.from_numpy(num_scheduled_tokens)
|
self.query_lens = torch.from_numpy(num_scheduled_tokens)
|
||||||
|
|
||||||
|
assigned_mask_dim = 2048
|
||||||
|
self.fia_attn_mask = torch.triu(torch.ones(assigned_mask_dim,
|
||||||
|
assigned_mask_dim),
|
||||||
|
diagonal=1).to(torch.int8).to(
|
||||||
|
self.device)
|
||||||
|
|
||||||
num_computed_tokens_cpu = (
|
num_computed_tokens_cpu = (
|
||||||
self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs])
|
self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs])
|
||||||
|
|
||||||
@@ -2805,6 +2818,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
num_computed_tokens_cpu=num_computed_tokens_cpu,
|
num_computed_tokens_cpu=num_computed_tokens_cpu,
|
||||||
positions=self.positions,
|
positions=self.positions,
|
||||||
attn_mask=self.attn_mask,
|
attn_mask=self.attn_mask,
|
||||||
|
fia_attn_mask=self.fia_attn_mask,
|
||||||
spec_attn_mask=self.spec_attn_mask,
|
spec_attn_mask=self.spec_attn_mask,
|
||||||
attn_state=self.attn_state,
|
attn_state=self.attn_state,
|
||||||
max_query_len=max_query_len,
|
max_query_len=max_query_len,
|
||||||
@@ -3978,10 +3992,12 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
graph_support = None
|
graph_support = None
|
||||||
if hasattr(builder, 'aclgraph_support'):
|
if hasattr(builder, 'aclgraph_support'):
|
||||||
graph_support = builder.aclgraph_support.value
|
graph_support = builder.aclgraph_support.value
|
||||||
|
builder_aclgraph = builder.aclgraph_support
|
||||||
else:
|
else:
|
||||||
graph_support = builder.cudagraph_support.value
|
graph_support = builder.cudagraph_support.value
|
||||||
|
builder_aclgraph = builder.cudagraph_support
|
||||||
if graph_support < min_ag_support.value:
|
if graph_support < min_ag_support.value:
|
||||||
min_ag_support = builder.aclgraph_support
|
min_ag_support = builder_aclgraph
|
||||||
min_ag_builder_name = builder.__class__.__name__
|
min_ag_builder_name = builder.__class__.__name__
|
||||||
|
|
||||||
# This is an imitation of compilation_config.splitting_ops_contain_attention()
|
# This is an imitation of compilation_config.splitting_ops_contain_attention()
|
||||||
|
|||||||
Reference in New Issue
Block a user