[Refactor] 7/N Extract common code to common_cp (#5490)
RFC: https://github.com/vllm-project/vllm-ascend/issues/4629
Reason: eliminate the code duplicated between two files (mla_cp.py and attention_cp.py) by extracting it into common_cp.py.

vLLM version: 0.13.0rc3 / vLLM main: ad32e3e19c
vLLM version: release/v0.13.0 / vLLM main: 5fbfa8d9ef
vLLM version: v0.13.0 / vLLM main: 5326c89803

---------

Signed-off-by: wujinyuan1 <wjy9595@qq.com>
Signed-off-by: wujinyuan1 <wujinyuan1@huawei.com>
Co-authored-by: wujinyuan1 <wjy9595@qq.com>
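The refactor replaces the per-class `_process_attn_out_lse` / `_npu_attention_update` methods in attention_cp.py and mla_cp.py with module-level helpers in context_parallel/common_cp.py. A minimal sketch of the resulting call-site shape (the class below is a hypothetical stand-in; only the two helper calls and their new signatures come from this diff):

    from vllm_ascend.attention.context_parallel.common_cp import (
        _npu_attention_update, _process_attn_out_lse)


    class SomeCPImpl:  # hypothetical stand-in for AscendAttentionCPImpl / AscendMlaCPImpl

        def _combine_decode_out(self, attn_out, attn_lse, decode_meta):
            # Mask and concatenate out & lse, then all-to-all across DCP and
            # all-gather across PCP inside the shared helper.
            attn_out_lse = _process_attn_out_lse(attn_out, attn_lse,
                                                 decode_meta.batch_seq_mask)
            # The head dim is self.head_size for the plain backend and
            # self.kv_lora_rank for the MLA backend.
            return _npu_attention_update(self.head_size, attn_out_lse)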
@@ -5,9 +5,11 @@ import torch
 from tests.ut.attention.utils import patch_distributed_groups
 from tests.ut.base import TestBase
-from vllm_ascend.attention.attention_cp import AscendAttentionCPImpl
-from vllm_ascend.attention.attention_v1 import (AscendMetadata,
-                                                AscendMetadataForPrefill)
+from vllm_ascend.attention.attention_v1 import AscendMetadata
+from vllm_ascend.attention.context_parallel.attention_cp import \
+    AscendAttentionCPImpl
+from vllm_ascend.attention.context_parallel.common_cp import (
+    AscendMetadataForPrefill, AscendPCPMetadata)
 
 
 class TestAscendAttentionCPImpl(TestBase):
@@ -82,25 +84,22 @@ class TestAscendAttentionCPImpl(TestBase):
         self.assertEqual(output.shape[1], 4)
         self.assertEqual(output.shape[2], 128)
 
+    @patch('torch_npu.npu_attention_update')
     @patch("torch_npu.npu_fused_infer_attention_score")
-    @patch('vllm_ascend.attention.attention_cp.get_forward_context')
+    @patch(
+        'vllm_ascend.attention.context_parallel.attention_cp.get_forward_context'
+    )
     @patch_distributed_groups(dcp_size=2, pcp_size=2)
     def test_forward_decode_pcp_dcp(self, mock_all2all, mock_dcp, mock_pcp,
                                     mock_get_forward_context,
-                                    mock_npu_fused_infer_attention_score):
-        query = torch.randn(2, 4, 128)
-        self.impl.key_cache = torch.randn(100, 128, 1, 128)
-        self.impl.value_cache = torch.randn(100, 128, 1, 128)
-
-        def mock_npu_attention_update(attn_out_lse_list):
-            mock_output = torch.randn(
-                attn_out_lse_list.shape[0] // mock_pcp.world_size,
-                attn_out_lse_list.shape[1] // mock_dcp.world_size,
-                attn_out_lse_list.shape[2] - 1)
-            return mock_output
-
-        self.impl._npu_attention_update = MagicMock()
-        self.impl._npu_attention_update.side_effect = mock_npu_attention_update
+                                    mock_npu_fused_infer_attention_score,
+                                    mock_npu_attention_update):
+        query = torch.randn(2, 4, 64)
+        self.impl.key_cache = torch.randn(100, 64, 1, 64)
+        self.impl.value_cache = torch.randn(100, 64, 1, 64)
+
+        # Mock output
+        mock_npu_attention_update.return_value = (torch.randn(2 * 4, 64), None)
 
         mock_get_forward_context.return_value = MagicMock(capturing=False)
@@ -116,12 +115,11 @@ class TestAscendAttentionCPImpl(TestBase):
         attn_metadata.decode_meta = MagicMock()
         attn_metadata.decode_meta.batch_seq_mask = torch.tensor(
             [1, 0], dtype=torch.bool)
 
         output = self.impl._forward_decode_pcp_dcp(query, attn_metadata)
 
         self.assertEqual(output.shape[0], 2)
         self.assertEqual(output.shape[1], 4)
-        self.assertEqual(output.shape[2], 128)
+        self.assertEqual(output.shape[2], 64)
 
     @patch_distributed_groups(dcp_size=2, pcp_size=2, needs_mocks=False)
     def test_prefill_query_all_gather(self):
@@ -249,7 +247,7 @@ class TestAscendAttentionCPImpl(TestBase):
         attn_metadata.slot_mapping = torch.randn(2)
         attn_metadata.num_actual_tokens_pcp_padded = num_tokens * self.impl.pcp_size
         attn_metadata.prefill = MagicMock()
-        attn_metadata.prefill.pcp_allgather_restore_idx = torch.tensor(
+        attn_metadata.prefill.pcp_metadata.pcp_allgather_restore_idx = torch.tensor(
             [0, 3, 1, 2, 0, 0, 0, 0])
 
         key = torch.randn(num_tokens, num_heads, head_size)
@@ -336,7 +334,7 @@ class TestUpdateNpuAttnOutLse(TestBase):
         attn_metadata.num_actual_tokens = self.q_total_tokens
 
         prefill_metadata = AscendMetadataForPrefill()
-        pcp_metadata = AscendMetadataForPrefill.AscendPCPMetadata()
+        pcp_metadata = AscendPCPMetadata()
         pcp_metadata.attn_mask_seqlens = self.kv_seqlens_mask_cumsum
         pcp_metadata.head_attn_nomask_seqlens = self.kv_seqlens_nomask_cumsum
         pcp_metadata.tail_attn_nomask_seqlens = self.kv_seqlens_nomask_cumsum
@@ -409,7 +407,7 @@ class TestUpdateNpuAttnOutLse(TestBase):
 
     @patch('torch.ops.npu.npu_fused_infer_attention_score')
     @patch(
-        'vllm_ascend.attention.attention_cp.AscendAttentionCPImpl._update_out_and_lse'
+        'vllm_ascend.attention.context_parallel.attention_cp.AscendAttentionCPImpl._update_out_and_lse'
     )
     def test_attention_with_nomask_and_mask_chunk(
             self, mock_update_out_and_lse,
@@ -457,7 +455,7 @@ class TestUpdateNpuAttnOutLse(TestBase):
 
     @patch('torch.ops.npu.npu_fused_infer_attention_score')
     @patch(
-        'vllm_ascend.attention.attention_cp.AscendAttentionCPImpl._npu_attn_out_lse_update'
+        'vllm_ascend.attention.context_parallel.attention_cp.AscendAttentionCPImpl._npu_attn_out_lse_update'
    )
     def test_attention_with_nomask_and_mask_nochunk(
             self, mock_npu_attn_out_lse_update,
@@ -505,7 +503,7 @@ class TestUpdateNpuAttnOutLse(TestBase):
         self.assertEqual(attn_lse, None)
 
     @patch(
-        'vllm_ascend.attention.attention_cp.AscendAttentionCPImpl._npu_attn_out_lse_update'
+        'vllm_ascend.attention.context_parallel.attention_cp.AscendAttentionCPImpl._npu_attn_out_lse_update'
     )
     def test_update_chunk_attn_out_lse_with_current_attn_out_lse(
             self, mock_npu_attn_out_lse_update):

@@ -7,8 +7,9 @@ from tests.ut.attention.utils import patch_distributed_groups
 from tests.ut.base import TestBase
 from vllm_ascend.ascend_config import init_ascend_config
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
-from vllm_ascend.attention.common_cp import CPChunkedContextMetadata
-from vllm_ascend.attention.mla_cp import AscendMlaCPImpl
+from vllm_ascend.attention.context_parallel.common_cp import (
+    CPChunkedContextMetadata, _npu_attention_update, _process_attn_out_lse)
+from vllm_ascend.attention.context_parallel.mla_cp import AscendMlaCPImpl
 from vllm_ascend.attention.mla_v1 import ChunkedContextMetadata
 
 
@@ -441,14 +442,14 @@ class TestAscendMLAImpl(TestBase):
         decode_metadata.batch_seq_mask = torch.tensor([True, False],
                                                       dtype=torch.bool)
 
-        result = self.impl._process_attn_out_lse(attn_output, softmax_lse,
-                                                 decode_metadata)
+        result = _process_attn_out_lse(attn_output, softmax_lse,
+                                       decode_metadata.batch_seq_mask)
 
         self.assertEqual(result.shape[0], B * self.impl.pcp_size)
         self.assertEqual(result.shape[1], N)
         self.assertEqual(result.shape[2], self.impl.kv_lora_rank + 1)
 
-    @patch('vllm_ascend.attention.mla_cp.get_forward_context')
+    @patch('vllm_ascend.attention.context_parallel.mla_cp.get_forward_context')
     @patch("torch_npu.atb.npu_multi_head_latent_attention")
     @patch('torch_npu.npu_attention_update')
     @patch_distributed_groups(dcp_size=2, pcp_size=2, needs_mocks=False)
@@ -725,7 +726,15 @@ class TestAscendMLAImpl(TestBase):
         assert torch.allclose(lse, expected_lse)
 
     @patch('torch_npu.npu_attention_update')
-    def test_npu_attention_update_with_dcp_pcp(self,
+    @patch('vllm_ascend.attention.context_parallel.common_cp.get_pcp_group')
+    @patch('vllm.distributed.parallel_state._PCP',
+           new_callable=lambda: MagicMock(spec=GroupCoordinator))
+    @patch('vllm_ascend.attention.context_parallel.common_cp.get_dcp_group')
+    @patch('vllm.distributed.parallel_state._DCP',
+           new_callable=lambda: MagicMock(spec=GroupCoordinator))
+    def test_npu_attention_update_with_dcp_pcp(self, mock_dcp,
+                                               mock_get_dcp_group, mock_pcp,
+                                               mock_get_pcp_group,
                                                mock_npu_attention_update):
         NUM_TOKENS = 10  # fixed
         test_cases = [(1, 1), (1, 2), (2, 1), (2, 2), (2, 3)]
@@ -752,10 +761,19 @@ class TestAscendMLAImpl(TestBase):
                                              attn_lse_split_cp[0])
 
         mock_npu_attention_update.side_effect = mock_npu_attention_update_effect
 
+        mock_pcp_group = MagicMock()
+        mock_pcp_group.world_size = self.impl.pcp_size
+        mock_get_pcp_group.return_value = mock_pcp_group
+
+        mock_dcp.world_size = self.impl.dcp_size
+        mock_dcp_group = MagicMock()
+        # mock_dcp_group.world_size = self.impl.dcp_size
+        mock_get_dcp_group.return_value = mock_dcp_group
         attn_out_lse = torch.randn(self.impl.pcp_size * NUM_TOKENS,
                                    self.impl.dcp_size * num_heads,
                                    head_dim)
-        out = self.impl._npu_attention_update(attn_out_lse)
+        out = _npu_attention_update(self.impl.kv_lora_rank, attn_out_lse)
         self.impl.dcp_size = 1
         self.impl.pcp_size = 1
         assert out.shape == (NUM_TOKENS, num_heads, self.impl.kv_lora_rank)
@@ -873,8 +891,8 @@ class TestAscendMLAImpl(TestBase):
         decode_meta = MagicMock()
         decode_meta.batch_seq_mask = batch_seq_mask
 
-        result = self.impl._process_attn_out_lse(attn_output, softmax_lse,
-                                                 decode_meta)
+        result = _process_attn_out_lse(attn_output, softmax_lse,
+                                       batch_seq_mask)
         # [PCP * S, DCP * H, D + 1]
         self.assertIsInstance(result, torch.Tensor)
         assert result.shape == (B * self.impl.pcp_size, H, D + 1)

@@ -34,10 +34,10 @@ from vllm.v1.attention.backends.utils import (AttentionCGSupport,
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.kv_cache_interface import AttentionSpec
 
+from vllm_ascend.attention.context_parallel.common_cp import (
+    AscendMetadataForDecode, AscendMetadataForPrefill)
 from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
-                                         AscendMetadataForDecode,
-                                         AscendMetadataForPrefill, enable_cp,
-                                         split_decodes_and_prefills,
+                                         enable_cp, split_decodes_and_prefills,
                                          using_paged_attention)
 from vllm_ascend.compilation.acl_graph import (
     get_draft_graph_params, get_graph_params,
@@ -63,7 +63,7 @@ class AscendAttentionBackend(AttentionBackend):
     @staticmethod
     def get_impl_cls() -> Type["AscendAttentionBackendImpl"]:
         if enable_cp():
-            from vllm_ascend.attention.attention_cp import \
+            from vllm_ascend.attention.context_parallel.attention_cp import \
                 AscendAttentionCPImpl
             return AscendAttentionCPImpl
         return AscendAttentionBackendImpl
@@ -71,7 +71,7 @@ class AscendAttentionBackend(AttentionBackend):
     @staticmethod
     def get_builder_cls() -> type["AscendAttentionMetadataBuilder"]:
         if enable_cp():
-            from vllm_ascend.attention.attention_cp import \
+            from vllm_ascend.attention.context_parallel.attention_cp import \
                 AscendAttentionCPMetadataBuilder
             return AscendAttentionCPMetadataBuilder
         return AscendAttentionMetadataBuilder

@@ -1,40 +0,0 @@
-from dataclasses import dataclass
-from typing import Optional
-
-import torch
-
-
-@dataclass
-class AscendPCPMetadata:
-    q_head_idx: torch.Tensor = None
-    q_tail_idx: torch.Tensor = None
-    kv_with_q_head_nomask_idx: torch.Tensor = None
-    kv_with_q_head_mask_idx: torch.Tensor = None
-    kv_with_q_tail_nomask_idx: torch.Tensor = None
-    kv_with_q_tail_mask_idx: torch.Tensor = None
-    attn_mask_seqlens: torch.Tensor = None
-    head_attn_nomask_seqlens: torch.Tensor = None
-    tail_attn_nomask_seqlens: torch.Tensor = None
-    q_full_idx: torch.Tensor = None
-    pcp_prefill_mask: torch.Tensor = None
-    pcp_allgather_restore_idx: Optional[list[int]] = None
-
-
-@dataclass
-class CPChunkedContextMetadata:
-    # New for MLA (compared to FlashAttention)
-    # For handling chunked prefill
-    cu_seq_lens: torch.Tensor
-    starts: torch.Tensor
-    seq_tot: list[int]
-    max_seq_lens: list[int]
-    workspace: torch.Tensor
-    chunk_seq_lens: torch.Tensor
-    chunk_seq_lens_npu: torch.Tensor
-    # for mla DCP & PCP
-    padded_chunk_seq_lens_npu: torch.Tensor = None
-    padded_local_chunk_seq_lens: Optional[list[list[int]]] = None
-    local_context_lens_allranks: Optional[list[list[int]]] = None
-    padded_local_cu_seq_lens: torch.Tensor = None
-    cu_seq_lens_lst: Optional[list[list[int]]] = None
-    chunk_size: Optional[int] = None
new file (empty): vllm_ascend/attention/context_parallel/__init__.py
@@ -33,9 +33,10 @@ from vllm.v1.kv_cache_interface import AttentionSpec
 from vllm_ascend.attention.attention_v1 import (AscendAttentionBackendImpl,
                                                 AscendAttentionMetadataBuilder,
                                                 AscendMetadata)
+from vllm_ascend.attention.context_parallel.common_cp import (
+    AscendMetadataForDecode, AscendMetadataForPrefill, AscendPCPMetadata,
+    _npu_attention_update, _process_attn_out_lse)
 from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
-                                         AscendMetadataForDecode,
-                                         AscendMetadataForPrefill,
                                          filter_chunked_req_indices,
                                          split_decodes_and_prefills)
 from vllm_ascend.compilation.acl_graph import (get_graph_params,
@@ -196,7 +197,7 @@ class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder):
             tail_attn_nomask_seqlens = torch.cumsum(
                 tail_attn_nomask_seqlens[1], dim=0).tolist()
 
-            pcp_metadata = AscendMetadataForPrefill.AscendPCPMetadata(
+            pcp_metadata = AscendPCPMetadata(
                 q_head_idx=common_long_seq_metadata.q_head_idx_tensor,
                 q_tail_idx=common_long_seq_metadata.q_tail_idx_tensor,
                 kv_with_q_head_nomask_idx=common_long_seq_metadata.
@@ -211,13 +212,12 @@ class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder):
                 head_attn_nomask_seqlens=head_attn_nomask_seqlens,
                 tail_attn_nomask_seqlens=tail_attn_nomask_seqlens,
                 q_full_idx=common_long_seq_metadata.q_full_idx,
-                pcp_prefill_mask=common_long_seq_metadata.pcp_prefill_mask)
+                pcp_prefill_mask=common_long_seq_metadata.pcp_prefill_mask,
+                pcp_allgather_restore_idx=common_long_seq_metadata.
+                pcp_allgather_restore_idx)
 
             prefill_metadata = AscendMetadataForPrefill(
                 pcp_metadata=pcp_metadata,
-                pcp_allgather_restore_idx=common_long_seq_metadata.
-                pcp_allgather_restore_idx
-                if common_long_seq_metadata is not None else None,
                 chunked_context=chunked_context_metadata,
                 block_tables=block_table[num_decodes:],
                 actual_seq_lengths_q=torch.cumsum(query_lens, dim=0))
@@ -460,39 +460,6 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl):
             attn_lse.shape[0] * attn_lse.shape[1] * attn_lse.shape[2])
         return attn_out, attn_lse
 
-    def _npu_attention_update(self,
-                              attn_out_lse: torch.Tensor) -> torch.Tensor:
-        B_total, H_total, D_plus_1 = attn_out_lse.shape
-        S = B_total // self.pcp_size
-        H = H_total // self.dcp_size
-        D = self.head_size
-        update_type = 0
-        assert D_plus_1 == D + 1
-        # [PCP, S, DCP, H, D+1]
-        x = attn_out_lse.view(self.pcp_size, S, self.dcp_size, H, D_plus_1)
-        # [PCP, DCP, S, H, D+1]
-        x = x.permute(0, 2, 1, 3, 4).contiguous()
-        # Flatten [N, S, H, D+1], N = pcp_size * dcp_size
-        x = x.view(-1, S, H, D_plus_1)
-        # Split out lse
-        # [N, S, H, D], [N, S, H, 1]
-        out_flat, lse_flat = torch.split(x, [D, 1], dim=-1)
-        # out: [N, S, H, D] -> [N, S*H, D]
-        # lse: [N, S, H, 1] -> [N, S*H]
-        out_flat = out_flat.flatten(1, 2)
-        lse_flat = lse_flat.squeeze(-1).flatten(1)
-        # unbind to list
-        # [S*H, D]
-        out_list = out_flat.unbind(0)
-        # [S*H]
-        lse_list = lse_flat.unbind(0)
-
-        attn_out, attn_lse = torch_npu.npu_attention_update(
-            lse_list, out_list, update_type)
-        attn_out = attn_out.view(S, H, D)
-
-        return attn_out
-
     def _forward_decode_pcp_dcp(self, query: torch.Tensor,
                                 attn_metadata: AscendMetadata) -> torch.Tensor:
         assert self.key_cache is not None
@@ -580,33 +547,9 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl):
         else:
             attn_out, attn_lse = torch_npu.npu_fused_infer_attention_score(
                 query, k_nope, value, **common_kwargs)
-        out_mask = attn_metadata.decode_meta.batch_seq_mask[:, None,
-                                                            None].expand_as(
-                                                                attn_out)
-        attn_out = torch.where(out_mask, 0, attn_out)
-
-        lse_mask = attn_metadata.decode_meta.batch_seq_mask[:, None,
-                                                            None].expand_as(
-                                                                attn_lse)
-        attn_lse = torch.where(lse_mask, -torch.inf, attn_lse)
-        # Concat out&lse: [bs,num_heads,v_head_dim] + [bs,num_heads,1] -> [bs,num_heads,v_head_dim+1]
-        attn_out_lse = torch.cat([attn_out, attn_lse], dim=-1)
-        if self.dcp_size > 1:
-            # permute: [bs, num_heads, v_head_dim+1] -> [num_heads, v_head_dim+1, bs]
-            attn_out_lse = attn_out_lse.permute([1, 2, 0]).contiguous()
-            attn_out_lse_all2all = torch.empty_like(attn_out_lse)
-            dist.all_to_all_single(attn_out_lse_all2all,
-                                   attn_out_lse,
-                                   group=self.dcp_group)
-            attn_out_lse = attn_out_lse_all2all.permute([2, 0, 1])
-
-        if self.pcp_size > 1:
-            # AllGather out&lse within CP group
-            attn_out_lse = get_pcp_group().all_gather(
-                attn_out_lse.contiguous(), dim=0)
-
-        attn_out = self._npu_attention_update(attn_out_lse)
+        attn_out_lse = _process_attn_out_lse(
+            attn_out, attn_lse, attn_metadata.decode_meta.batch_seq_mask)
+        attn_out = _npu_attention_update(self.head_size, attn_out_lse)
 
         return attn_out
 
     def _update_out_and_lse(self, out_list: torch.Tensor,
@@ -780,7 +723,9 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl):
             num_actual_tokens_pcp_padded = attn_metadata.num_actual_tokens_pcp_padded // self.pcp_size
             all_kv = get_pcp_group().all_gather(
                 kv[:num_actual_tokens_pcp_padded].contiguous(), dim=0)
-            pcp_allgather_restore_idx = attn_metadata.prefill.pcp_allgather_restore_idx if attn_metadata.prefill else None
+            assert attn_metadata.prefill is not None
+            assert attn_metadata.prefill.pcp_metadata is not None
+            pcp_allgather_restore_idx = attn_metadata.prefill.pcp_metadata.pcp_allgather_restore_idx
             all_kv = torch.index_select(all_kv, 0,
                                         pcp_allgather_restore_idx)
             key, value = all_kv.split([self.head_size, self.head_size],

new file (137 lines): vllm_ascend/attention/context_parallel/common_cp.py
@@ -0,0 +1,137 @@
+from dataclasses import dataclass
+from typing import Optional
+
+import torch
+import torch.distributed as dist
+import torch_npu
+from vllm.distributed import (get_dcp_group,
+                              get_decode_context_model_parallel_world_size,
+                              get_pcp_group)
+
+
+@dataclass
+class AscendPCPMetadata:
+    q_head_idx: torch.Tensor = None
+    q_tail_idx: torch.Tensor = None
+    kv_with_q_head_nomask_idx: torch.Tensor = None
+    kv_with_q_head_mask_idx: torch.Tensor = None
+    kv_with_q_tail_nomask_idx: torch.Tensor = None
+    kv_with_q_tail_mask_idx: torch.Tensor = None
+    attn_mask_seqlens: torch.Tensor = None
+    head_attn_nomask_seqlens: torch.Tensor = None
+    tail_attn_nomask_seqlens: torch.Tensor = None
+    q_full_idx: torch.Tensor = None
+    pcp_prefill_mask: torch.Tensor = None
+    pcp_allgather_restore_idx: Optional[list[int]] = None
+
+
+@dataclass
+class CPChunkedContextMetadata:
+    # New for MLA (compared to FlashAttention)
+    # For handling chunked prefill
+    cu_seq_lens: torch.Tensor
+    starts: torch.Tensor
+    seq_tot: list[int]
+    max_seq_lens: list[int]
+    workspace: torch.Tensor
+    chunk_seq_lens: torch.Tensor
+    chunk_seq_lens_npu: torch.Tensor
+    # for mla DCP & PCP
+    padded_chunk_seq_lens_npu: torch.Tensor = None
+    padded_local_chunk_seq_lens: Optional[list[list[int]]] = None
+    local_context_lens_allranks: Optional[list[list[int]]] = None
+    padded_local_cu_seq_lens: torch.Tensor = None
+    cu_seq_lens_lst: Optional[list[list[int]]] = None
+    chunk_size: Optional[int] = None
+
+
+@dataclass
+class AscendMetadataForPrefill:
+
+    @dataclass
+    class ChunkedContextMetadata:
+        actual_chunk_seq_lengths: torch.Tensor
+        actual_seq_lengths_kv: torch.Tensor
+        starts: torch.Tensor
+        chunk_seq_mask_filtered_indices: torch.Tensor
+        chunked_req_mask: Optional[list[bool]] = None
+        local_context_lens_allranks: Optional[list[list[int]]] = None
+        cp_kv_recover_idx_for_chunk: Optional[list[int]] = None
+        kv_inverse_idx_for_chunk: Optional[list[int]] = None
+        batch_chunk_seq_mask: Optional[list[bool]] = None
+        local_total_toks: Optional[int] = None
+
+    """ Prefill Specific Metadata for Ascend"""
+    pcp_metadata: Optional[AscendPCPMetadata] = None
+    chunked_context: Optional[ChunkedContextMetadata] = None
+    block_tables: torch.Tensor = None
+    actual_seq_lengths_q: torch.Tensor = None
+
+
+@dataclass
+class AscendMetadataForDecode:
+    """ Decode Specific Metadata for Ascend"""
+    num_computed_tokens_of_pcp_dcp: Optional[list[list[list[int]]]] = None
+    batch_seq_mask: torch.Tensor = None
+    block_tables: torch.Tensor = None
+
+
+def _process_attn_out_lse(attn_output: torch.Tensor, softmax_lse: torch.Tensor,
+                          batch_seq_mask: torch.Tensor) -> torch.Tensor:
+    pcp_size = get_pcp_group().world_size
+    dcp_size = get_decode_context_model_parallel_world_size()
+    dcp_group = get_dcp_group().device_group if dcp_size > 1 else None
+    out_mask = batch_seq_mask[:, None, None].expand_as(attn_output)
+    attn_output = torch.where(out_mask, 0, attn_output)
+    lse_mask = batch_seq_mask[:, None, None].expand_as(softmax_lse)
+    softmax_lse = torch.where(lse_mask, -torch.inf, softmax_lse)
+    softmax_lse = softmax_lse.to(torch.float32)
+    attn_output = attn_output.to(torch.float32)
+    # Concat out&lse: [bs,num_heads,v_head_dim] + [bs,num_heads,1] -> [bs,num_heads,v_head_dim+1]
+    attn_out_lse = torch.cat([attn_output, softmax_lse], dim=-1)
+    if dcp_size > 1:
+        # permute: [bs, num_heads, v_head_dim+1] -> [num_heads, v_head_dim+1, bs]
+        attn_out_lse = attn_out_lse.permute([1, 2, 0]).contiguous()
+        attn_out_lse_all2all = torch.empty_like(attn_out_lse)
+        dist.all_to_all_single(attn_out_lse_all2all,
+                               attn_out_lse,
+                               group=dcp_group)
+        attn_out_lse = attn_out_lse_all2all.permute([2, 0, 1])
+
+    if pcp_size > 1:
+        # AllGather out&lse within CP group
+        attn_out_lse = get_pcp_group().all_gather(attn_out_lse.contiguous(),
+                                                  dim=0)
+
+    return attn_out_lse
+
+
+def _npu_attention_update(head_size,
+                          attn_out_lse: torch.Tensor) -> torch.Tensor:
+    pcp_size = get_pcp_group().world_size
+    dcp_size = get_decode_context_model_parallel_world_size()
+    # [PCP * S, DCP * H, D+1]
+    B_total, H_total, D_plus_1 = attn_out_lse.shape
+    S = B_total // pcp_size
+    H = H_total // dcp_size
+    D = head_size
+    assert D_plus_1 == D + 1
+    # [PCP, S, DCP, H, D+1]
+    x = attn_out_lse.view(pcp_size, S, dcp_size, H, D_plus_1)
+    # [PCP, DCP, S, H, D+1]
+    x = x.permute(0, 2, 1, 3, 4).contiguous()
+    # Flatten [N, S, H, D+1], N = pcp_size * dcp_size
+    x = x.view(-1, S, H, D_plus_1)
+    # Split out lse
+    out_flat, lse_flat = torch.split(x, [D, 1],
+                                     dim=-1)  # [N, S, H, D], [N, S, H, 1]
+    # out: [N, S, H, D] -> [N, S*H, D]
+    # lse: [N, S, H, 1] -> [N, S*H]
+    out_flat = out_flat.flatten(1, 2)  # [N, S*H, D]
+    lse_flat = lse_flat.flatten(1, -1)  # [N, S*H]
+    # unbind to list
+    out_list = out_flat.unbind(0)  # [S*H, D]
+    lse_list = lse_flat.unbind(0)  # [S*H]
+    attn_out, _ = torch_npu.npu_attention_update(lse_list, out_list, 0)
+    attn_out = attn_out.view(-1, H, D)
+    return attn_out
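For reference, _npu_attention_update regroups the gathered [PCP * S, DCP * H, D + 1] tensor into per-rank [S * H, D] outputs and [S * H] log-sum-exp vectors before calling torch_npu.npu_attention_update. A shape-only sketch with assumed sizes (plain PyTorch; the NPU op itself is not called here):

    import torch

    pcp_size, dcp_size, S, H, D = 2, 2, 3, 4, 8            # assumed sizes
    attn_out_lse = torch.randn(pcp_size * S, dcp_size * H, D + 1)

    x = attn_out_lse.view(pcp_size, S, dcp_size, H, D + 1)
    x = x.permute(0, 2, 1, 3, 4).contiguous()               # [PCP, DCP, S, H, D+1]
    x = x.view(-1, S, H, D + 1)                             # [N, S, H, D+1], N = PCP * DCP
    out_flat, lse_flat = torch.split(x, [D, 1], dim=-1)
    out_list = out_flat.flatten(1, 2).unbind(0)             # N tensors of [S*H, D]
    lse_list = lse_flat.flatten(1, -1).unbind(0)            # N tensors of [S*H]
    assert out_list[0].shape == (S * H, D) and lse_list[0].shape == (S * H,)
    # torch_npu.npu_attention_update(lse_list, out_list, 0) combines these partial
    # results; the combined output is then viewed back to [S, H, D].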
@@ -2,7 +2,6 @@ from typing import Optional, Tuple, TypeVar
 
 import numpy as np
 import torch
-import torch.distributed as dist
 import torch_npu
 from vllm.config import VllmConfig
 from vllm.distributed import (get_dcp_group,
@@ -15,17 +14,17 @@ from vllm.v1.attention.backends.utils import AttentionCGSupport
 from vllm.v1.kv_cache_interface import AttentionSpec, MLAAttentionSpec
 
 # isort: off
-from vllm_ascend.attention.mla_v1 import (AscendMLADecodeMetadata,
-                                          AscendMLAImpl, AscendMLAMetadata,
-                                          AscendMLAMetadataBuilder,
-                                          AscendMLAPrefillMetadata,
-                                          DecodeMLAPreprocessResult,
-                                          PrefillMLAPreprocessResult)
+from vllm_ascend.attention.mla_v1 import (
+    AscendMLADecodeMetadata, AscendMLAImpl, AscendMLAMetadata,
+    AscendMLAMetadataBuilder, AscendMLAPrefillMetadata,
+    DecodeMLAPreprocessResult, PrefillMLAPreprocessResult,
+    BUILD_METADATA_STEP_PREFILL)
 #isort: on
 
 from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata)
-from vllm_ascend.attention.common_cp import (AscendPCPMetadata,
-                                             CPChunkedContextMetadata)
+from vllm_ascend.attention.context_parallel.common_cp import (
+    AscendPCPMetadata, CPChunkedContextMetadata, _process_attn_out_lse,
+    _npu_attention_update)
 from vllm_ascend.compilation.acl_graph import (get_draft_graph_params,
                                                get_graph_params,
                                                update_graph_params_workspaces)
@@ -89,6 +88,8 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
         if long_seq_metadata is None:
             raise AssertionError("long_seq_metadata should not be None.")
 
+        # In dcp only spec decode graph padding case,
+        # num_actual_tokens_pcp_padded may be less than num_actual_tokens
         self.num_actual_tokens = max(
             long_seq_metadata.num_actual_tokens_pcp_padded,
             common_attn_metadata.num_actual_tokens)
@@ -187,21 +188,17 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
         )
         return chunked_metadata
 
-    def set_prefill_block_table(
-        self,
-        common_attn_metadata: AscendCommonAttentionMetadata,
-    ):
-        # For pcp + spec decode, we flatten seq_lens and block_table
-        # to avoid irregular spec_attn_mask shape
+    def get_block_table_size(
+            self, common_attn_metadata: AscendCommonAttentionMetadata,
+            build_metadata_step: int):
         self.num_decodes_flatten = self.query_lens[:self.num_decodes].sum(
         ).item()
-        self.block_table = common_attn_metadata.block_table_tensor[:self.
-                                                                   num_decodes_flatten
-                                                                   + self.
-                                                                   num_prefills]
-
-    def set_decode_block_table(self):
-        self.block_table = self.block_table[:self.num_decodes_flatten, ...]
+        if build_metadata_step == BUILD_METADATA_STEP_PREFILL:
+            # For pcp + spec decode, we flatten seq_lens and block_table
+            # to avoid irregular spec_attn_mask shape
+            return self.num_decodes_flatten + self.num_prefills
+        else:
+            return self.num_decodes_flatten
 
     def build_prefill_metadata(
         self,
@@ -637,39 +634,11 @@ class AscendMlaCPImpl(AscendMLAImpl):
                 lse=softmax_lse)
 
         # Update out&lse
-        attn_out_lse = self._process_attn_out_lse(attn_output, softmax_lse,
-                                                  decode_meta)
-        attn_output = self._npu_attention_update(attn_out_lse)
+        attn_out_lse = _process_attn_out_lse(attn_output, softmax_lse,
+                                             decode_meta.batch_seq_mask)
+        attn_output = _npu_attention_update(self.kv_lora_rank, attn_out_lse)
         return self._v_up_proj(attn_output)
 
-    def _npu_attention_update(self,
-                              attn_out_lse: torch.Tensor) -> torch.Tensor:
-        # [PCP * S, DCP * H, D+1]
-        B_total, H_total, D_plus_1 = attn_out_lse.shape
-        S = B_total // self.pcp_size
-        H = H_total // self.dcp_size
-        D = self.kv_lora_rank
-        assert D_plus_1 == D + 1
-        # [PCP, S, DCP, H, D+1]
-        x = attn_out_lse.view(self.pcp_size, S, self.dcp_size, H, D_plus_1)
-        # [PCP, DCP, S, H, D+1]
-        x = x.permute(0, 2, 1, 3, 4).contiguous()
-        # Flatten [N, S, H, D+1], N = pcp_size * dcp_size
-        x = x.view(-1, S, H, D_plus_1)
-        # Split out lse
-        out_flat, lse_flat = torch.split(x, [D, 1],
-                                         dim=-1)  # [N, S, H, D], [N, S, H, 1]
-        # out: [N, S, H, D] -> [N, S*H, D]
-        # lse: [N, S, H, 1] -> [N, S*H]
-        out_flat = out_flat.flatten(1, 2)  # [N, S*H, D]
-        lse_flat = lse_flat.flatten(1, -1)  # [N, S*H]
-        # unbind to list
-        out_list = out_flat.unbind(0)  # [S*H, D]
-        lse_list = lse_flat.unbind(0)  # [S*H]
-        attn_out, _ = torch_npu.npu_attention_update(lse_list, out_list, 0)
-        attn_out = attn_out.view(-1, H, D)
-        return attn_out
-
     def _out_lse_reshape(self, attn_out: torch.Tensor,
                          attn_lse: torch.Tensor) -> torch.Tensor:
         attn_out = attn_out.contiguous().view(
@@ -678,39 +647,6 @@ class AscendMlaCPImpl(AscendMLAImpl):
             attn_lse.shape[0] * attn_lse.shape[1] * attn_lse.shape[2])
         return attn_out, attn_lse
 
-    def _process_attn_out_lse(
-        self,
-        attn_output: torch.Tensor,
-        softmax_lse: torch.Tensor,
-        decode_meta: AscendMLADecodeMetadata,
-    ) -> torch.Tensor:
-        out_mask = decode_meta.batch_seq_mask[:, None,
-                                              None].expand_as(attn_output)
-        attn_output = torch.where(out_mask, 0, attn_output)
-        lse_mask = decode_meta.batch_seq_mask[:, None,
-                                              None].expand_as(softmax_lse)
-        softmax_lse = torch.where(lse_mask, -torch.inf, softmax_lse)
-
-        softmax_lse = softmax_lse.to(torch.float32)
-        attn_output = attn_output.to(torch.float32)
-        # Concat out&lse: [bs,num_heads,v_head_dim] + [bs,num_heads,1] -> [bs,num_heads,v_head_dim+1]
-        attn_out_lse = torch.cat([attn_output, softmax_lse], dim=-1)
-        if self.dcp_size > 1:
-            # permute: [bs, num_heads, v_head_dim+1] -> [num_heads, v_head_dim+1, bs]
-            attn_out_lse = attn_out_lse.permute([1, 2, 0]).contiguous()
-            attn_out_lse_all2all = torch.empty_like(attn_out_lse)
-            dist.all_to_all_single(attn_out_lse_all2all,
-                                   attn_out_lse,
-                                   group=self.dcp_group)
-            attn_out_lse = attn_out_lse_all2all.permute([2, 0, 1])
-
-        if self.pcp_size > 1:
-            # AllGather out&lse within CP group
-            attn_out_lse = get_pcp_group().all_gather(
-                attn_out_lse.contiguous(), dim=0)
-
-        return attn_out_lse
-
     def _reorg_kvcache(
         self,
         kv_c_normed: torch.Tensor,
@@ -19,8 +19,8 @@ from vllm.v1.kv_cache_interface import AttentionSpec, MLAAttentionSpec
 from vllm_ascend import envs
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
-from vllm_ascend.attention.common_cp import (AscendPCPMetadata,
-                                             CPChunkedContextMetadata)
+from vllm_ascend.attention.context_parallel.common_cp import (
+    AscendPCPMetadata, CPChunkedContextMetadata)
 from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
                                          enable_cp,
                                          maybe_save_kv_layer_to_connector,
@@ -46,6 +46,8 @@ if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
 
 MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024
+BUILD_METADATA_STEP_PREFILL = 0
+BUILD_METADATA_STEP_DECODE = 1
 
 
 class AscendMLABackend(AttentionBackend):
@@ -61,7 +63,8 @@ class AscendMLABackend(AttentionBackend):
     @staticmethod
     def get_builder_cls():
         if enable_cp():
-            from vllm_ascend.attention.mla_cp import AscendMlaCPMetadataBuilder
+            from vllm_ascend.attention.context_parallel.mla_cp import \
+                AscendMlaCPMetadataBuilder
             return AscendMlaCPMetadataBuilder
         return AscendMLAMetadataBuilder
 
@@ -73,7 +76,8 @@ class AscendMLABackend(AttentionBackend):
     @staticmethod
     def get_impl_cls() -> Type["MLAAttentionImpl"]:
         if enable_cp():
-            from vllm_ascend.attention.mla_cp import AscendMlaCPImpl
+            from vllm_ascend.attention.context_parallel.mla_cp import \
+                AscendMlaCPImpl
             return AscendMlaCPImpl
         return AscendMLAImpl
 
@@ -418,7 +422,11 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]):
         self.query_lens = query_seq_lens_cpu[:num_reqs]
         self.seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
 
-        self.set_prefill_block_table(common_attn_metadata)
+        self.graph_pad_size = common_attn_metadata.graph_pad_size
+        block_table_size = self.get_block_table_size(
+            common_attn_metadata, BUILD_METADATA_STEP_PREFILL)
+        self.block_table = common_attn_metadata.block_table_tensor[:
+                                                                   block_table_size]
 
         prefill_metadata = None
         if self.num_prefills > 0:
@@ -499,23 +507,16 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]):
                 workspace=self.chunked_prefill_workspace,
             )
 
-    def set_prefill_block_table(
-        self,
-        common_attn_metadata: AscendCommonAttentionMetadata,
-    ):
-        # If graph_pad_size > -1, mean is running in fullgraph mode.
-        self.graph_pad_size = common_attn_metadata.graph_pad_size
-        # NOTE: Maybe this block_table change can be removed when graph_pad_size > 1.
-        if self.graph_pad_size > common_attn_metadata.num_reqs and self.speculative_config.disable_padded_drafter_batch:
-            self.block_table = (
-                common_attn_metadata.block_table_tensor[:self.graph_pad_size])
-        else:
-            self.block_table = (
-                common_attn_metadata.block_table_tensor[:common_attn_metadata.
-                                                        num_reqs])
-
-    def set_decode_block_table(self):
-        self.block_table = self.block_table[:self.num_decodes, ...]
+    def get_block_table_size(
+            self, common_attn_metadata: AscendCommonAttentionMetadata,
+            build_metadata_step: int):
+        if build_metadata_step == BUILD_METADATA_STEP_PREFILL:
+            # If graph_pad_size > -1, mean is running in fullgraph mode.
+            # NOTE: Maybe this block_table change can be removed when graph_pad_size > 1.
+            if self.graph_pad_size > common_attn_metadata.num_reqs and self.speculative_config.disable_padded_drafter_batch:
+                return self.graph_pad_size
+            return common_attn_metadata.num_reqs
+        return self.num_decodes
 
     def build_prefill_metadata(
         self,
@@ -574,7 +575,9 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]):
         self.seq_lens = self.seq_lens[:self.num_decodes]
         input_positions = input_positions[:self.num_decode_tokens]
 
-        self.set_decode_block_table()
+        block_table_size = self.get_block_table_size(
+            common_attn_metadata, BUILD_METADATA_STEP_DECODE)
+        self.block_table = self.block_table[:block_table_size]
 
         # NOTE: Currently, MTP-fullgraph is incompatibility pcp
         # NOTE: Maybe this block_table change can be removed when graph_pad_size > 1.

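The former set_prefill_block_table / set_decode_block_table pair becomes a single get_block_table_size that is consulted once per build step (BUILD_METADATA_STEP_PREFILL, then BUILD_METADATA_STEP_DECODE), and the caller slices block_table itself. Restated with plain integers as an illustration (the standalone function and parameter names below are assumptions; the real logic lives on the metadata builders above):

    BUILD_METADATA_STEP_PREFILL = 0
    BUILD_METADATA_STEP_DECODE = 1


    def get_block_table_size(step, graph_pad_size, num_reqs, num_decodes,
                             padded_drafter_batch_disabled):
        if step == BUILD_METADATA_STEP_PREFILL:
            # Fullgraph mode pads the batch, so the prefill slice must cover the pad.
            if graph_pad_size > num_reqs and padded_drafter_batch_disabled:
                return graph_pad_size
            return num_reqs
        # The decode step trims the previously sliced table to the decode requests.
        return num_decodes


    print(get_block_table_size(BUILD_METADATA_STEP_PREFILL, 8, 4, 2, True))  # 8
    print(get_block_table_size(BUILD_METADATA_STEP_DECODE, 8, 4, 2, True))   # 2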
@@ -35,52 +35,6 @@ def enable_cp():
             or prefill_config.decode_context_parallel_size > 1
 
 
-@dataclass
-class AscendMetadataForPrefill:
-
-    @dataclass
-    class AscendPCPMetadata:
-        q_head_idx: torch.Tensor = None
-        q_tail_idx: torch.Tensor = None
-        kv_with_q_head_nomask_idx: torch.Tensor = None
-        kv_with_q_head_mask_idx: torch.Tensor = None
-        kv_with_q_tail_nomask_idx: torch.Tensor = None
-        kv_with_q_tail_mask_idx: torch.Tensor = None
-        attn_mask_seqlens: torch.Tensor = None
-        head_attn_nomask_seqlens: torch.Tensor = None
-        tail_attn_nomask_seqlens: torch.Tensor = None
-        q_full_idx: torch.Tensor = None
-        pcp_prefill_mask: torch.Tensor = None
-
-    @dataclass
-    class ChunkedContextMetadata:
-        actual_chunk_seq_lengths: torch.Tensor
-        actual_seq_lengths_kv: torch.Tensor
-        starts: torch.Tensor
-        chunk_seq_mask_filtered_indices: torch.Tensor
-        chunked_req_mask: Optional[list[bool]] = None
-        local_context_lens_allranks: Optional[list[list[int]]] = None
-        cp_kv_recover_idx_for_chunk: Optional[list[int]] = None
-        kv_inverse_idx_for_chunk: Optional[list[int]] = None
-        batch_chunk_seq_mask: Optional[list[bool]] = None
-        local_total_toks: Optional[int] = None
-
-    """ Prefill Specific Metadata for Ascend"""
-    pcp_metadata: Optional[AscendPCPMetadata] = None
-    pcp_allgather_restore_idx: Optional[List[int]] = None
-    chunked_context: Optional[ChunkedContextMetadata] = None
-    block_tables: torch.Tensor = None
-    actual_seq_lengths_q: torch.Tensor = None
-
-
-@dataclass
-class AscendMetadataForDecode:
-    """ Decode Specific Metadata for Ascend"""
-    num_computed_tokens_of_pcp_dcp: Optional[list[list[list[int]]]] = None
-    batch_seq_mask: torch.Tensor = None
-    block_tables: torch.Tensor = None
-
-
 @dataclass
 # class AscendCommonLongSequenceMetadata:
 class AscendPrefillContextParallelMetadata: