[Refactor] 5/N Extract common code of mla_v1.py & extract mla_cp (#5097)

RFC: https://github.com/vllm-project/vllm-ascend/issues/4629
Reason:
The CP-related functions differ significantly from normal MLA attention,
yet the two are tightly coupled inside the MLA metadata builder.

Steps:
1) Extract the common code of AscendMLAMetadataBuilder.build into 4 functions:
build_prefill_metadata, build_decode_metadata, build_cp_metadata, and
build_chunked_metadata (see the sketch below).
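
A simplified sketch of the resulting call structure (class and method names taken
from this diff; argument lists kept, bodies heavily abbreviated, so this is an
illustration rather than a verbatim copy of the code):

    class AscendMLAMetadataBuilder:
        def build(self, common_prefix_len, common_attn_metadata, model):
            # split_decodes_and_prefills sets self.num_decodes / self.num_prefills
            # and the token counts (elided here)
            self.set_num_actual_tokens(common_attn_metadata)
            self.set_prefill_block_table(common_attn_metadata)
            prefill_metadata = None
            if self.num_prefills > 0:
                # build_prefill_metadata calls build_chunked_metadata internally
                prefill_metadata = self.build_prefill_metadata(
                    common_prefix_len, common_attn_metadata, model)
            decode_metadata = None
            if self.num_decodes > 0:
                decode_metadata = self.build_decode_metadata(
                    common_prefix_len, common_attn_metadata, model)
            return self.metadata_cls(prefill=prefill_metadata,
                                     decode=decode_metadata)  # other fields elided

    class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
        def build_prefill_metadata(self, common_prefix_len,
                                   common_attn_metadata, model):
            prefill_metadata = super().build_prefill_metadata(
                common_prefix_len, common_attn_metadata, model)
            # the CP builder layers PCP metadata on top of the common result
            prefill_metadata.pcp_metadata = self.build_cp_metadata(
                common_prefix_len, common_attn_metadata, model)
            return prefill_metadata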

Todo:
1) Refactor _compute_prefill_context.
2) Refactor _mla_preprocess and _mla_decode_preprocess.
3) Extract the shared data structures and processing functions from
attention_cp.py and mla_cp.py into the common_cp file (see the import
sketch below).
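
A minimal sketch of where step 3 is heading, assuming attention_cp.py will
adopt the same import that mla_cp.py already uses in this change:

    # Both CP backends share the dataclasses from the new common module
    # (AscendPCPMetadata and CPChunkedContextMetadata, added in common_cp.py below).
    from vllm_ascend.attention.common_cp import (AscendPCPMetadata,
                                                 CPChunkedContextMetadata)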

- vLLM version: 0.13.0rc3
- vLLM main: ad32e3e19c

---------

Signed-off-by: wujinyuan1 <wjy9595@qq.com>
Signed-off-by: wujinyuan1 <wujinyuan1@huawei.com>
Co-authored-by: wujinyuan1 <wjy9595@qq.com>
Co-authored-by: weijinqian0 <1184188277@qq.com>
Author: wujinyuan1
Date: 2025-12-24 10:25:19 +08:00 (committed by GitHub)
Commit: 7ff1db4b84
Parent: 2a2d527e96
6 changed files with 545 additions and 718 deletions

View File

@@ -6,8 +6,9 @@ from vllm.distributed.parallel_state import GroupCoordinator
from tests.ut.base import TestBase
from vllm_ascend.ascend_config import init_ascend_config
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.common_cp import CPChunkedContextMetadata
from vllm_ascend.attention.mla_cp import AscendMlaCPImpl
from vllm_ascend.attention.mla_v1 import AscendMLAPrefillMetadata
from vllm_ascend.attention.mla_v1 import ChunkedContextMetadata
def get_pcp_split_info(pcp_rank, pcp_size, seq_lens):
@@ -127,7 +128,7 @@ def get_chunk_metadata(pcp_size, dcp_size, num_prefills, num_decodes,
out=padded_local_cu_chunk_seq_lens_cpu[:, 1:],
dtype=torch.int32,
)
chunked_context_metadata = AscendMLAPrefillMetadata.ChunkedContextMetadata(
chunked_context_metadata = CPChunkedContextMetadata(
cu_seq_lens=cu_seq_lens_cpu.to(non_blocking=True),
starts=local_chunk_starts.to(non_blocking=True),
seq_tot=padded_local_chunk_seq_lens.sum(dim=1).tolist(),
@@ -144,16 +145,15 @@ def get_chunk_metadata(pcp_size, dcp_size, num_prefills, num_decodes,
chunk_size=padded_local_max_context_chunk_across_ranks,
)
else:
chunked_context_metadata = (
AscendMLAPrefillMetadata.ChunkedContextMetadata(
cu_seq_lens=cu_seq_lens_cpu.to(non_blocking=True),
starts=chunk_starts.to(non_blocking=True),
seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
chunk_seq_lens=chunk_seq_lens,
chunk_seq_lens_npu=chunk_seq_lens,
workspace=None,
))
chunked_context_metadata = (ChunkedContextMetadata(
cu_seq_lens=cu_seq_lens_cpu.to(non_blocking=True),
starts=chunk_starts.to(non_blocking=True),
seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
chunk_seq_lens=chunk_seq_lens,
chunk_seq_lens_npu=chunk_seq_lens,
workspace=None,
))
return chunked_context_metadata

View File

@@ -14,7 +14,8 @@ from vllm_ascend.attention.mla_v1 import (AscendMLABackend,
AscendMLADecodeMetadata,
AscendMLAImpl, AscendMLAMetadata,
AscendMLAMetadataBuilder,
AscendMLAPrefillMetadata)
AscendMLAPrefillMetadata,
ChunkedContextMetadata)
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
@@ -76,27 +77,15 @@ class TestAscendMLAPrefillMetadata(TestBase):
max_seq_lens = [2, 2]
workspace = torch.randn(2, 4)
chunk_seq_lens = torch.tensor([2, 2])
padded_chunk_seq_lens_npu = torch.tensor([2, 2])
padded_local_chunk_seq_lens = [[2], [2]]
local_context_lens_allranks = [[1, 1], [1, 1]]
padded_local_cu_seq_lens = torch.tensor([0, 2, 4])
cu_seq_lens_lst = [[0, 2], [2, 4]]
chunk_size = 2
chunked_context = AscendMLAPrefillMetadata.ChunkedContextMetadata(
chunked_context = ChunkedContextMetadata(
cu_seq_lens=cu_seq_lens,
starts=starts,
seq_tot=seq_tot,
max_seq_lens=max_seq_lens,
workspace=workspace,
chunk_seq_lens=chunk_seq_lens,
chunk_seq_lens_npu=chunk_seq_lens,
padded_chunk_seq_lens_npu=padded_chunk_seq_lens_npu,
padded_local_chunk_seq_lens=padded_local_chunk_seq_lens,
local_context_lens_allranks=local_context_lens_allranks,
padded_local_cu_seq_lens=padded_local_cu_seq_lens,
cu_seq_lens_lst=cu_seq_lens_lst,
chunk_size=chunk_size)
chunk_seq_lens_npu=chunk_seq_lens)
metadata = AscendMLAPrefillMetadata(
attn_mask=torch.tensor([[1, 0], [1, 1]], dtype=torch.bool),
@@ -119,17 +108,6 @@ class TestAscendMLAPrefillMetadata(TestBase):
self.assertIs(metadata.chunked_context.chunk_seq_lens, chunk_seq_lens)
self.assertIs(metadata.chunked_context.chunk_seq_lens_npu,
chunk_seq_lens)
self.assertIs(metadata.chunked_context.padded_chunk_seq_lens_npu,
padded_chunk_seq_lens_npu)
self.assertEqual(metadata.chunked_context.padded_local_chunk_seq_lens,
padded_local_chunk_seq_lens)
self.assertEqual(metadata.chunked_context.local_context_lens_allranks,
local_context_lens_allranks)
self.assertIs(metadata.chunked_context.padded_local_cu_seq_lens,
padded_local_cu_seq_lens)
self.assertEqual(metadata.chunked_context.cu_seq_lens_lst,
cu_seq_lens_lst)
self.assertEqual(metadata.chunked_context.chunk_size, chunk_size)
class TestAscendMLADecodeMetadata(TestBase):
@@ -218,11 +196,9 @@ class TestAscendMLAMetadataBuilder(TestBase):
@patch('vllm.distributed.parallel_state.get_dcp_group')
@patch('vllm.distributed.parallel_state._DCP',
new_callable=lambda: MagicMock(spec=GroupCoordinator))
@patch("vllm.distributed.get_decode_context_model_parallel_world_size",
return_value=1)
def test_ascend_mla_metadata_builder_default(self, mock_get_dcp_size,
mock_dcp, mock_get_dcp_group,
mock_pcp, mock_get_pcp_group):
def test_ascend_mla_metadata_builder_default(self, mock_dcp,
mock_get_dcp_group, mock_pcp,
mock_get_pcp_group):
mock_vllm_config = MagicMock()
mock_vllm_config.model_config.max_model_len = 1024
mock_vllm_config.model_config.get_head_size.return_value = 64
@@ -262,8 +238,6 @@ class TestAscendMLAMetadataBuilder(TestBase):
self.assertEqual(
builder.chunked_prefill_enabled,
mock_vllm_config.scheduler_config.enable_chunked_prefill)
self.assertEqual(builder.dcp_size, mock_dcp.world_size)
self.assertEqual(builder.pcp_size, mock_pcp.world_size)
@patch('vllm.distributed.parallel_state.get_pcp_group')
@patch('vllm.distributed.parallel_state._PCP',
@@ -271,10 +245,7 @@ class TestAscendMLAMetadataBuilder(TestBase):
@patch('vllm.distributed.parallel_state.get_dcp_group')
@patch('vllm.distributed.parallel_state._DCP',
new_callable=lambda: MagicMock(spec=GroupCoordinator))
@patch("vllm.distributed.get_decode_context_model_parallel_world_size",
return_value=1)
def test_ascend_mla_metadata_builder_spec_decode(self, mock_get_dcp_size,
mock_dcp,
def test_ascend_mla_metadata_builder_spec_decode(self, mock_dcp,
mock_get_dcp_group,
mock_pcp,
mock_get_pcp_group):
@@ -324,11 +295,8 @@ class TestAscendMLAMetadataBuilder(TestBase):
@patch('vllm.distributed.parallel_state.get_dcp_group')
@patch('vllm.distributed.parallel_state._DCP',
new_callable=lambda: MagicMock(spec=GroupCoordinator))
@patch("vllm.distributed.get_decode_context_model_parallel_world_size",
return_value=1)
def test_ascend_mla_metadata_builder_build_full_graph(
self, mock_get_dcp_size, mock_dcp, mock_get_dcp_group, mock_pcp,
mock_get_pcp_group):
self, mock_dcp, mock_get_dcp_group, mock_pcp, mock_get_pcp_group):
mock_vllm_config = MagicMock()
mock_vllm_config.model_config.max_model_len = 1024
mock_vllm_config.model_config.get_head_size.return_value = 64
@@ -387,10 +355,8 @@ class TestAscendMLAMetadataBuilder(TestBase):
@patch('vllm.distributed.parallel_state.get_dcp_group')
@patch('vllm.distributed.parallel_state._DCP',
new_callable=lambda: MagicMock(spec=GroupCoordinator))
@patch("vllm.distributed.get_decode_context_model_parallel_world_size",
return_value=1)
def test_reorder_batch(self, mock_get_dcp_size, mock_dcp,
mock_get_dcp_group, mock_pcp, mock_get_pcp_group):
def test_reorder_batch(self, mock_dcp, mock_get_dcp_group, mock_pcp,
mock_get_pcp_group):
ascend_config = MagicMock()
mock_vllm_config = MagicMock()
@@ -448,10 +414,7 @@ class TestAscendMLAMetadataBuilder(TestBase):
@patch('vllm.distributed.parallel_state.get_dcp_group')
@patch('vllm.distributed.parallel_state._DCP',
new_callable=lambda: MagicMock(spec=GroupCoordinator))
@patch("vllm.distributed.get_decode_context_model_parallel_world_size",
return_value=1)
def test_pad_actual_seq_lens_q_mtp_disable_pad(self, mock_get_dcp_size,
mock_dcp,
def test_pad_actual_seq_lens_q_mtp_disable_pad(self, mock_dcp,
mock_get_dcp_group,
mock_pcp,
mock_get_pcp_group):
@@ -496,11 +459,8 @@ class TestAscendMLAMetadataBuilder(TestBase):
@patch('vllm.distributed.parallel_state.get_dcp_group')
@patch('vllm.distributed.parallel_state._DCP',
new_callable=lambda: MagicMock(spec=GroupCoordinator))
@patch("vllm.distributed.get_decode_context_model_parallel_world_size",
return_value=1)
def test_pad_actual_seq_lens_q_mtp_enable_pad(self, mock_get_dcp_size,
mock_dcp, mock_get_dcp_group,
mock_pcp,
def test_pad_actual_seq_lens_q_mtp_enable_pad(self, mock_dcp,
mock_get_dcp_group, mock_pcp,
mock_get_pcp_group):
mock_vllm_config = MagicMock()
mock_vllm_config.model_config.max_model_len = 1024
@@ -566,17 +526,14 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
self.kv_cache_spec.head_size = 128
self.kv_cache_spec.num_heads = 32
@patch("vllm_ascend.attention.mla_v1.get_pcp_group")
@patch(
"vllm_ascend.attention.mla_v1.get_decode_context_model_parallel_world_size"
)
@patch("vllm_ascend.attention.mla_v1.get_ascend_config")
@patch('vllm.distributed.parallel_state.get_pcp_group')
@patch("vllm.distributed.get_decode_context_model_parallel_world_size",
return_value=1)
@patch("vllm_ascend.attention.mla_v1.torch.zeros", wraps=torch.zeros)
@patch("torch.Tensor.npu", new=lambda self: self)
@patch("torch.npu.is_available")
def test_build_prefix_no_cache_metadata(self, mock_npu_available,
mock_zeros, mock_get_ascend_config,
mock_dcp_world_size,
mock_zeros, mock_dcp_world_size,
mock_get_pcp_group):
mock_npu_available.return_value = False
mock_dcp_world_size.return_value = 1
@@ -633,17 +590,14 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
torch.all(metadata.slot_mapping == base_inputs["slot_mapping"]))
self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size)
@patch("vllm_ascend.attention.mla_v1.get_pcp_group")
@patch(
"vllm_ascend.attention.mla_v1.get_decode_context_model_parallel_world_size"
)
@patch("vllm_ascend.attention.mla_v1.get_ascend_config")
@patch('vllm.distributed.parallel_state.get_pcp_group')
@patch("vllm.distributed.get_decode_context_model_parallel_world_size",
return_value=1)
@patch("vllm_ascend.attention.mla_v1.torch.zeros", wraps=torch.zeros)
@patch("torch.Tensor.npu", new=lambda self: self)
@patch("torch.npu.is_available")
def test_build_chunked_prefix_metadata(self, mock_npu_available,
mock_zeros, mock_get_ascend_config,
mock_dcp_world_size,
mock_zeros, mock_dcp_world_size,
mock_get_pcp_group):
mock_npu_available.return_value = False
mock_dcp_world_size.return_value = 1
@@ -701,13 +655,10 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
torch.all(metadata.slot_mapping == base_inputs["slot_mapping"]))
self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size)
@patch("vllm_ascend.attention.mla_v1.get_pcp_group")
@patch(
"vllm_ascend.attention.mla_v1.get_decode_context_model_parallel_world_size"
)
@patch("vllm_ascend.attention.mla_v1.get_ascend_config")
def test_build_decode_only_metadata(self, mock_get_ascend_config,
mock_dcp_world_size,
@patch('vllm.distributed.parallel_state.get_pcp_group')
@patch("vllm.distributed.get_decode_context_model_parallel_world_size",
return_value=1)
def test_build_decode_only_metadata(self, mock_dcp_world_size,
mock_get_pcp_group):
mock_dcp_world_size.return_value = 1
torch.Tensor.pin_memory = lambda x: x # noqa
@@ -757,13 +708,10 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
torch.all(metadata.slot_mapping == base_inputs["slot_mapping"]))
self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size)
@patch("vllm_ascend.attention.mla_v1.get_pcp_group")
@patch(
"vllm_ascend.attention.mla_v1.get_decode_context_model_parallel_world_size"
)
@patch("vllm_ascend.attention.mla_v1.get_ascend_config")
def test_build_for_graph_capture_decode_only(self, mock_get_ascend_config,
mock_dcp_world_size,
@patch('vllm.distributed.parallel_state.get_pcp_group')
@patch("vllm.distributed.get_decode_context_model_parallel_world_size",
return_value=1)
def test_build_for_graph_capture_decode_only(self, mock_dcp_world_size,
mock_get_pcp_group):
mock_dcp_world_size.return_value = 1
torch.Tensor.pin_memory = lambda x: x # noqa
@@ -814,13 +762,10 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
torch.all(metadata.slot_mapping == base_inputs["slot_mapping"]))
self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size)
@patch("vllm_ascend.attention.mla_v1.get_pcp_group")
@patch(
"vllm_ascend.attention.mla_v1.get_decode_context_model_parallel_world_size"
)
@patch("vllm_ascend.attention.mla_v1.get_ascend_config")
def test_build_for_graph_capture_prefill(self, mock_get_ascend_config,
mock_dcp_world_size,
@patch('vllm.distributed.parallel_state.get_pcp_group')
@patch("vllm.distributed.get_decode_context_model_parallel_world_size",
return_value=1)
def test_build_for_graph_capture_prefill(self, mock_dcp_world_size,
mock_get_pcp_group):
mock_dcp_world_size.return_value = 1
torch.Tensor.pin_memory = lambda x: x # noqa
@@ -868,16 +813,10 @@ class TestAscendMLAImpl(TestBase):
new_callable=lambda: MagicMock(spec=GroupCoordinator))
@patch('vllm.distributed.parallel_state._DCP',
new_callable=lambda: MagicMock(spec=GroupCoordinator))
@patch("vllm.distributed.get_decode_context_model_parallel_world_size",
return_value=1)
@patch('vllm.distributed.parallel_state._TP',
new_callable=lambda: MagicMock(spec=GroupCoordinator))
@patch("vllm.distributed.get_tensor_model_parallel_world_size",
return_value=2)
@patch("vllm_ascend.attention.mla_v1.get_current_vllm_config")
@patch("vllm_ascend.attention.mla_v1.get_ascend_config")
def setUp(self, ascend_config, get_current_vllm_config, mock_get_tp_size,
mock_tp, mock_get_dcp_size, mock_dcp, mock_pcp):
def setUp(self, get_current_vllm_config, mock_tp, mock_dcp, mock_pcp):
mock_tp.world_size = 2
mock_tp.rank_in_group = MagicMock()
mock_tp.device_group = MagicMock()

View File

@@ -0,0 +1,40 @@
from dataclasses import dataclass
from typing import Optional
import torch
@dataclass
class AscendPCPMetadata:
q_head_idx: torch.Tensor = None
q_tail_idx: torch.Tensor = None
kv_with_q_head_nomask_idx: torch.Tensor = None
kv_with_q_head_mask_idx: torch.Tensor = None
kv_with_q_tail_nomask_idx: torch.Tensor = None
kv_with_q_tail_mask_idx: torch.Tensor = None
attn_mask_seqlens: torch.Tensor = None
head_attn_nomask_seqlens: torch.Tensor = None
tail_attn_nomask_seqlens: torch.Tensor = None
q_full_idx: torch.Tensor = None
pcp_prefill_mask: torch.Tensor = None
pcp_allgather_restore_idx: Optional[list[int]] = None
@dataclass
class CPChunkedContextMetadata:
# New for MLA (compared to FlashAttention)
# For handling chunked prefill
cu_seq_lens: torch.Tensor
starts: torch.Tensor
seq_tot: list[int]
max_seq_lens: list[int]
workspace: torch.Tensor
chunk_seq_lens: torch.Tensor
chunk_seq_lens_npu: torch.Tensor
# for mla DCP & PCP
padded_chunk_seq_lens_npu: torch.Tensor = None
padded_local_chunk_seq_lens: Optional[list[list[int]]] = None
local_context_lens_allranks: Optional[list[list[int]]] = None
padded_local_cu_seq_lens: torch.Tensor = None
cu_seq_lens_lst: Optional[list[list[int]]] = None
chunk_size: Optional[int] = None

View File

@@ -5,31 +5,32 @@ import torch
import torch.distributed as dist
import torch_npu
from torch import nn
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.config import VllmConfig
from vllm.distributed import (get_dcp_group,
get_decode_context_model_parallel_rank,
get_decode_context_model_parallel_world_size,
get_pcp_group)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.utils.math_utils import cdiv, round_down
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm.v1.kv_cache_interface import MLAAttentionSpec
# isort: off
from vllm_ascend.attention.mla_v1 import (AscendMLADecodeMetadata,
AscendMLAImpl, AscendMLAMetadata,
AscendMLAMetadataBuilder,
AscendMLAPrefillMetadata,
DecodeMLAPreprocessResult,
PrefillMLAPreprocessResult)
#isort: on
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
maybe_save_kv_layer_to_connector,
split_decodes_and_prefills,
wait_for_kv_layer_from_connector)
from vllm_ascend.attention.common_cp import AscendPCPMetadata, CPChunkedContextMetadata
from vllm_ascend.compilation.acl_graph import (get_graph_params,
get_mtp_graph_params,
update_graph_params_workspaces)
from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla
from vllm_ascend.ops.shared_weight_layer import (
is_hidden_layer, reach_layer_for_shared_weight_series)
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
@@ -75,354 +76,173 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
dtype=torch.uint8,
device=device)
def build(
def set_num_actual_tokens(
self,
common_attn_metadata: AscendCommonAttentionMetadata,
):
long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
if long_seq_metadata is None:
raise AssertionError("long_seq_metadata should not be None.")
self.num_actual_tokens = max(
long_seq_metadata.num_actual_tokens_pcp_padded,
common_attn_metadata.num_actual_tokens)
def build_cp_metadata(
self,
common_prefix_len: int,
common_attn_metadata: AscendCommonAttentionMetadata,
model: nn.Module,
) -> AscendMLAMetadata:
num_reqs = common_attn_metadata.num_reqs
num_actual_tokens = common_attn_metadata.num_actual_tokens
query_start_loc = common_attn_metadata.query_start_loc
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
) -> AscendPCPMetadata | None:
common_long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
assert common_long_seq_metadata is not None
return AscendPCPMetadata(
q_head_idx=common_long_seq_metadata.q_head_idx_tensor,
q_tail_idx=common_long_seq_metadata.q_tail_idx_tensor,
kv_with_q_head_nomask_idx=common_long_seq_metadata.
kv_with_q_head_nomask_idx_tensor,
kv_with_q_head_mask_idx=common_long_seq_metadata.
kv_with_q_head_mask_idx_tensor,
kv_with_q_tail_nomask_idx=common_long_seq_metadata.
kv_with_q_tail_nomask_idx_tensor,
kv_with_q_tail_mask_idx=common_long_seq_metadata.
kv_with_q_tail_mask_idx_tensor,
attn_mask_seqlens=common_long_seq_metadata.attn_mask_seqlens,
head_attn_nomask_seqlens=common_long_seq_metadata.
head_attn_nomask_seqlens,
tail_attn_nomask_seqlens=common_long_seq_metadata.
tail_attn_nomask_seqlens,
q_full_idx=common_long_seq_metadata.q_full_idx,
pcp_prefill_mask=common_long_seq_metadata.pcp_prefill_mask,
pcp_allgather_restore_idx=common_long_seq_metadata.
pcp_allgather_restore_idx)
def build_chunked_metadata(
self,
common_prefix_len: int,
common_attn_metadata: AscendCommonAttentionMetadata,
model: nn.Module,
):
chunked_context_metadata = super().build_chunked_metadata(
common_prefix_len, common_attn_metadata, model)
if chunked_context_metadata is None:
return None
long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
if long_seq_metadata is None:
raise AssertionError("long_seq_metadata should not be None.")
num_actual_tokens_pcp_padded = long_seq_metadata.num_actual_tokens_pcp_padded
if num_actual_tokens_pcp_padded is None:
num_actual_tokens_pcp_padded = num_actual_tokens
# In dcp only spec decode graph padding case,
# num_actual_tokens_pcp_padded may be less than num_actual_tokens
num_actual_tokens_pcp_padded = max(num_actual_tokens_pcp_padded,
num_actual_tokens)
assert long_seq_metadata is not None
num_computed_tokens_of_pcp_dcp = long_seq_metadata.num_computed_tokens_of_pcp_dcp
assert num_computed_tokens_of_pcp_dcp is not None
local_context_lens_allranks = torch.tensor(
num_computed_tokens_of_pcp_dcp[self.num_decodes_flatten:]).reshape(
-1, self.dcp_size * self.pcp_size)
# Note(qcs): The max local context lengths
# padded to `cp_local_block_size`.
padded_local_context_lens_cpu = (cdiv(
self.context_lens_cpu,
self.cp_virtual_block_size,
) * self.cp_local_block_size)
padded_local_max_context_chunk_across_ranks = (cdiv(
self.max_context_chunk,
self.cp_virtual_block_size,
) * self.cp_local_block_size)
local_chunk_starts = (torch.arange(
self.num_chunks, dtype=torch.int32).unsqueeze(1).expand(
-1, self.num_prefills) *
padded_local_max_context_chunk_across_ranks)
local_chunk_ends = torch.min(
padded_local_context_lens_cpu.unsqueeze(0),
local_chunk_starts + padded_local_max_context_chunk_across_ranks,
)
padded_local_chunk_seq_lens = (local_chunk_ends -
local_chunk_starts).clamp(min=0)
padded_local_cu_chunk_seq_lens_cpu = torch.zeros(self.num_chunks,
self.num_prefills + 1,
dtype=torch.int32,
pin_memory=True)
torch.cumsum(
padded_local_chunk_seq_lens,
dim=1,
out=padded_local_cu_chunk_seq_lens_cpu[:, 1:],
dtype=torch.int32,
)
chunked_metadata = CPChunkedContextMetadata(
cu_seq_lens=chunked_context_metadata.cu_seq_lens,
starts=local_chunk_starts.pin_memory().to(self.device,
non_blocking=True),
seq_tot=padded_local_chunk_seq_lens.sum(dim=1).tolist(),
max_seq_lens=chunked_context_metadata.max_seq_lens,
chunk_seq_lens=self.chunk_seq_lens,
chunk_seq_lens_npu=chunked_context_metadata.chunk_seq_lens_npu,
workspace=chunked_context_metadata.workspace,
padded_chunk_seq_lens_npu=padded_local_chunk_seq_lens.npu(),
padded_local_chunk_seq_lens=padded_local_chunk_seq_lens.tolist(),
local_context_lens_allranks=local_context_lens_allranks.tolist(),
padded_local_cu_seq_lens=padded_local_cu_chunk_seq_lens_cpu.
pin_memory().to(self.device, non_blocking=True),
cu_seq_lens_lst=self.cu_seq_lens_cpu.tolist(),
chunk_size=padded_local_max_context_chunk_across_ranks,
)
return chunked_metadata
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
split_decodes_and_prefills(common_attn_metadata, decode_threshold=self.decode_threshold)
assert num_decodes + num_prefills == num_reqs
assert num_decode_tokens + num_prefill_tokens == num_actual_tokens
# Note(simon): be careful about the CPU <> GPU memory movement in this
# function. We should avoid GPU -> CPU sync as much as possible because
# it blocks on all previous kernels.
device = self.device
# If graph_pad_size > -1, mean is running in fullgraph mode.
graph_pad_size = common_attn_metadata.graph_pad_size
# NOTE: Maybe this block_table change can be removed when graph_pad_size > 1.
if graph_pad_size > num_reqs and self.speculative_config.disable_padded_drafter_batch:
block_table = (
common_attn_metadata.block_table_tensor[:graph_pad_size])
else:
block_table = (common_attn_metadata.block_table_tensor[:num_reqs])
slot_mapping = common_attn_metadata.slot_mapping[:
num_actual_tokens_pcp_padded]
input_positions = common_attn_metadata.positions[:
num_actual_tokens_pcp_padded].long(
)
if self.cos_cache is None:
self.cos_cache = model.model.layers[
model.model.start_layer].self_attn.rotary_emb.cos_cached
self.sin_cache = model.model.layers[
model.model.start_layer].self_attn.rotary_emb.sin_cached
if self.cos_cache.dtype != self.model_config.dtype: # type: ignore
self.cos_cache = self.cos_cache.to( # type: ignore
self.model_config.dtype) # type: ignore
self.sin_cache = self.sin_cache.to( # type: ignore
self.model_config.dtype) # type: ignore
query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
query_lens = query_seq_lens_cpu[:num_reqs]
seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
num_computed_tokens_cpu = (seq_lens - query_lens)
def set_prefill_block_table(
self,
common_attn_metadata: AscendCommonAttentionMetadata,
):
# For pcp + spec decode, we flatten seq_lens and block_table
# to avoid irregular spec_attn_mask shape
num_decodes_flatten = query_lens[:num_decodes].sum().item()
block_table = common_attn_metadata.block_table_tensor[:
num_decodes_flatten
+ num_prefills]
self.num_decodes_flatten = self.query_lens[:self.num_decodes].sum(
).item()
self.block_table = common_attn_metadata.block_table_tensor[:self.
num_decodes_flatten
+ self.
num_prefills]
prefill_metadata = None
chunked_context_metadata = None
if num_prefills > 0:
pcp_metadata = AscendMLAPrefillMetadata.AscendPCPMetadata(
q_head_idx=long_seq_metadata.q_head_idx_tensor,
q_tail_idx=long_seq_metadata.q_tail_idx_tensor,
kv_with_q_head_nomask_idx=long_seq_metadata.
kv_with_q_head_nomask_idx_tensor,
kv_with_q_head_mask_idx=long_seq_metadata.
kv_with_q_head_mask_idx_tensor,
kv_with_q_tail_nomask_idx=long_seq_metadata.
kv_with_q_tail_nomask_idx_tensor,
kv_with_q_tail_mask_idx=long_seq_metadata.
kv_with_q_tail_mask_idx_tensor,
attn_mask_seqlens=long_seq_metadata.attn_mask_seqlens,
head_attn_nomask_seqlens=long_seq_metadata.
head_attn_nomask_seqlens,
tail_attn_nomask_seqlens=long_seq_metadata.
tail_attn_nomask_seqlens,
q_full_idx=long_seq_metadata.q_full_idx,
pcp_prefill_mask=long_seq_metadata.pcp_prefill_mask,
pcp_allgather_restore_idx=long_seq_metadata.
pcp_allgather_restore_idx)
def set_decode_block_table(
self, common_attn_metadata: AscendCommonAttentionMetadata):
self.block_table = self.block_table[:self.num_decodes_flatten, ...]
reqs_start = num_decodes # prefill_start
tokens_start = num_decode_tokens
max_query_len = query_lens[reqs_start:].max().item()
max_seq_lens = seq_lens[reqs_start:].max().item()
prefill_query_start_loc = query_start_loc[
reqs_start:] - query_start_loc[reqs_start]
def build_prefill_metadata(
self,
common_prefix_len: int,
common_attn_metadata: AscendCommonAttentionMetadata,
model: nn.Module,
) -> AscendMLAPrefillMetadata:
prefill_metadata = super().build_prefill_metadata(
common_prefix_len, common_attn_metadata, model)
prefill_metadata.pcp_metadata = self.build_cp_metadata(
common_prefix_len, common_attn_metadata, model)
prefill_metadata.block_table = self.block_table[
self.num_decodes_flatten:, ...]
return prefill_metadata
context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs]
max_context_len_cpu = context_lens_cpu.max().item()
num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item()
if self.chunked_prefill_enabled and max_context_len_cpu > 0:
max_context_chunk = (self.chunked_prefill_workspace_size //
num_prefills_with_context_cpu)
max_context_chunk = round_down(max_context_chunk,
self.block_size)
def build_decode_metadata(
self,
common_prefix_len: int,
common_attn_metadata: AscendCommonAttentionMetadata,
model: nn.Module,
) -> AscendMLADecodeMetadata:
decode_metadata = super().build_decode_metadata(
common_prefix_len, common_attn_metadata, model)
assert max_context_chunk > 0
num_chunks = cdiv(max_context_len_cpu, max_context_chunk)
chunk_starts = torch.arange(num_chunks, dtype=torch.int32) \
.unsqueeze(1).expand(-1, num_prefills) * max_context_chunk
chunk_ends = torch.min(context_lens_cpu.unsqueeze(0),
chunk_starts + max_context_chunk)
chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)
cu_seq_lens_cpu = torch.zeros(num_chunks,
num_prefills + 1,
dtype=torch.int32,
pin_memory=True)
torch.cumsum(chunk_seq_lens,
dim=1,
out=cu_seq_lens_cpu[:, 1:],
dtype=torch.int32)
long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
assert long_seq_metadata is not None
num_computed_tokens_of_pcp_dcp = long_seq_metadata.num_computed_tokens_of_pcp_dcp
assert num_computed_tokens_of_pcp_dcp is not None
# [bs, pcp_size, dcp_size]
num_computed_tokens_of_cp_dcp_array = np.array(
num_computed_tokens_of_pcp_dcp)[:self.num_decodes_flatten]
local_context_lens_allranks = torch.tensor(
num_computed_tokens_of_pcp_dcp[num_decodes_flatten:]
).reshape(-1, self.dcp_size * self.pcp_size)
# Note(qcs): The max local context lengths
# padded to `cp_local_block_size`.
padded_local_context_lens_cpu = (cdiv(
context_lens_cpu,
self.cp_virtual_block_size,
) * self.cp_local_block_size)
padded_local_max_context_chunk_across_ranks = (cdiv(
max_context_chunk,
self.cp_virtual_block_size,
) * self.cp_local_block_size)
local_chunk_starts = (
torch.arange(num_chunks,
dtype=torch.int32).unsqueeze(1).expand(
-1, num_prefills) *
padded_local_max_context_chunk_across_ranks)
local_chunk_ends = torch.min(
padded_local_context_lens_cpu.unsqueeze(0),
local_chunk_starts +
padded_local_max_context_chunk_across_ranks,
)
padded_local_chunk_seq_lens = (local_chunk_ends -
local_chunk_starts).clamp(min=0)
padded_local_cu_chunk_seq_lens_cpu = torch.zeros(
num_chunks,
num_prefills + 1,
dtype=torch.int32,
pin_memory=True)
torch.cumsum(
padded_local_chunk_seq_lens,
dim=1,
out=padded_local_cu_chunk_seq_lens_cpu[:, 1:],
dtype=torch.int32,
)
chunked_context_metadata = AscendMLAPrefillMetadata.ChunkedContextMetadata(
cu_seq_lens=cu_seq_lens_cpu.pin_memory().to(
device, non_blocking=True),
starts=local_chunk_starts.pin_memory().to(
device, non_blocking=True),
seq_tot=padded_local_chunk_seq_lens.sum(dim=1).tolist(),
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
chunk_seq_lens=chunk_seq_lens,
chunk_seq_lens_npu=chunk_seq_lens.npu(),
workspace=self.chunked_prefill_workspace,
padded_chunk_seq_lens_npu=padded_local_chunk_seq_lens.npu(
),
padded_local_chunk_seq_lens=padded_local_chunk_seq_lens.
tolist(),
local_context_lens_allranks=local_context_lens_allranks.
tolist(),
padded_local_cu_seq_lens=padded_local_cu_chunk_seq_lens_cpu
.pin_memory().to(device, non_blocking=True),
cu_seq_lens_lst=cu_seq_lens_cpu.tolist(),
chunk_size=padded_local_max_context_chunk_across_ranks,
)
prefill_input_positions = input_positions[tokens_start:]
assert self.cos_cache is not None
assert self.sin_cache is not None
cos = self.cos_cache[prefill_input_positions].unsqueeze(
1).unsqueeze(2)
sin = self.sin_cache[prefill_input_positions].unsqueeze(
1).unsqueeze(2)
prefill_metadata = AscendMLAPrefillMetadata(
attn_mask=common_attn_metadata.attn_mask,
query_lens=query_lens[reqs_start:].to(torch.int32),
seq_lens=seq_lens,
context_lens=seq_lens[reqs_start:],
input_positions=prefill_input_positions,
block_table=block_table[reqs_start:, ...],
max_query_len=max_query_len,
max_seq_lens=max_seq_lens,
query_start_loc=prefill_query_start_loc,
chunked_context=chunked_context_metadata,
sin=sin,
cos=cos,
pcp_metadata=pcp_metadata,
)
prefill_metadata.block_table = \
block_table[num_decodes_flatten:, ...]
decode_metadata = None
if num_decodes > 0:
cos, sin = get_cos_and_sin_mla()
# Notice that num_decodes != num_decode_tokens in SpecDecoding Scenario
actual_seq_lengths_q = query_start_loc_cpu[1:num_decodes +
1].tolist()
max_seq_lens = seq_lens[:num_decodes].max().item()
seq_lens = seq_lens[:num_decodes]
input_positions = input_positions[:num_decode_tokens]
block_table = block_table[:num_decodes_flatten, ...]
# NOTE: Maybe this block_table change can be removed when graph_pad_size > 1.
if graph_pad_size > num_decodes and \
self.speculative_config.disable_padded_drafter_batch:
block_table = block_table[:graph_pad_size, ...]
seq_lens_list = seq_lens.tolist()
# [bs, pcp_size, dcp_size]
num_computed_tokens_of_cp_dcp_array = np.array(
num_computed_tokens_of_pcp_dcp)[:num_decodes_flatten]
cp_seq_len = num_computed_tokens_of_cp_dcp_array[:, self.pcp_rank,
self.dcp_rank]
cp_seq_len = torch.tensor(cp_seq_len, dtype=torch.int32)
batch_seq_mask = (cp_seq_len == 0)
self.batch_seq_mask_buf[:batch_seq_mask.shape[0]].copy_(
batch_seq_mask, non_blocking=True)
batch_seq_mask = self.batch_seq_mask_buf[:batch_seq_mask.shape[0]]
cp_seq_len = torch.where(cp_seq_len == 0, 1, cp_seq_len)
if graph_pad_size > num_reqs:
if self.speculative_config.disable_padded_drafter_batch:
num_reqs_pad_size = graph_pad_size - num_reqs
actual_seq_lengths_q = self.pad_actual_seq_len_q_mtp_disable_pad(
num_reqs_pad_size, num_reqs, actual_seq_lengths_q)
seq_lens_list = seq_lens_list + [0] * (graph_pad_size -
num_decodes)
num_block_pad_size = graph_pad_size - block_table.shape[0]
if num_block_pad_size > 0:
block_table_padding = torch.zeros(
(num_block_pad_size, ) + block_table.shape[1:],
dtype=block_table.dtype,
device=block_table.device)
block_table = torch.cat(
[block_table, block_table_padding], dim=0)
else:
num_token_pad_size = graph_pad_size - num_decode_tokens
num_reqs_pad_size = (
graph_pad_size //
common_attn_metadata.decode_token_per_req - num_reqs)
num_block_table_pad_size = (
graph_pad_size //
common_attn_metadata.decode_token_per_req -
num_decodes)
seq_lens_list = seq_lens.tolist() + [0] * num_reqs_pad_size
slot_padding = torch.full((num_token_pad_size, ),
PAD_SLOT_ID,
dtype=slot_mapping.dtype,
device=slot_mapping.device)
slot_mapping = torch.cat([slot_mapping, slot_padding])
block_table_padding = torch.zeros(
(num_block_table_pad_size, ) + block_table.shape[1:],
dtype=block_table.dtype,
device=block_table.device)
block_table = torch.cat([block_table, block_table_padding],
dim=0)
position_padding = torch.zeros(
num_token_pad_size,
dtype=input_positions.dtype,
device=input_positions.device)
input_positions = torch.cat(
[input_positions, position_padding])
actual_seq_lengths_q = self.pad_actual_seq_len_q_mtp_enable_pad(
num_reqs_pad_size, num_reqs, actual_seq_lengths_q,
common_attn_metadata)
# TODO: After the fullgraph supports MTP, the if branch needs to deleted
assert self.cos_cache is not None
assert self.sin_cache is not None
if cos is None and sin is None:
cos = self.cos_cache[
input_positions].unsqueeze( # type: ignore
1).unsqueeze(2)
sin = self.sin_cache[
input_positions].unsqueeze( # type: ignore
1).unsqueeze(2)
decode_metadata = AscendMLADecodeMetadata(
input_positions=input_positions,
block_table=block_table,
seq_lens=seq_lens,
seq_lens_list=seq_lens_list,
max_seq_lens=max_seq_lens,
attn_mask=common_attn_metadata.spec_attn_mask,
actual_seq_lengths_q=actual_seq_lengths_q,
sin=sin,
cos=cos,
cp_seq_len=cp_seq_len,
batch_seq_mask=batch_seq_mask)
else:
cos[:num_decode_tokens,
...] = self.cos_cache[input_positions].unsqueeze(
1).unsqueeze(2)
sin[:num_decode_tokens,
...] = self.sin_cache[input_positions].unsqueeze(
1).unsqueeze(2)
decode_metadata = AscendMLADecodeMetadata(
input_positions=input_positions,
block_table=block_table,
seq_lens=seq_lens,
seq_lens_list=seq_lens_list,
max_seq_lens=max_seq_lens,
attn_mask=common_attn_metadata.spec_attn_mask,
actual_seq_lengths_q=actual_seq_lengths_q,
sin=sin[:num_decode_tokens, ...],
cos=cos[:num_decode_tokens, ...],
cp_seq_len=cp_seq_len,
batch_seq_mask=batch_seq_mask)
return self.metadata_cls( # type: ignore
num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded,
num_input_tokens=common_attn_metadata.num_input_tokens,
num_actual_tokens=num_actual_tokens,
query_lens=query_lens.tolist(),
slot_mapping=slot_mapping,
head_dim=self.model_config.get_head_size(),
num_decodes=num_decodes,
num_decode_tokens=num_decode_tokens,
num_prefills=num_prefills,
attn_mask=common_attn_metadata.attn_mask,
attn_state=common_attn_metadata.attn_state,
prefill=prefill_metadata,
decode=decode_metadata,
query_start_loc=query_start_loc,
block_tables=block_table,
seq_lens=seq_lens,
)
cp_seq_len = num_computed_tokens_of_cp_dcp_array[:, self.pcp_rank,
self.dcp_rank]
cp_seq_len = torch.tensor(cp_seq_len, dtype=torch.int32)
batch_seq_mask = (cp_seq_len == 0)
self.batch_seq_mask_buf[:batch_seq_mask.shape[0]].copy_(
batch_seq_mask, non_blocking=True)
batch_seq_mask = self.batch_seq_mask_buf[:batch_seq_mask.shape[0]]
cp_seq_len = torch.where(cp_seq_len == 0, 1, cp_seq_len)
decode_metadata.cp_seq_len = cp_seq_len
decode_metadata.batch_seq_mask = batch_seq_mask
return decode_metadata
class AscendMlaCPImpl(AscendMLAImpl):

View File

@@ -9,9 +9,6 @@ from torch import nn
from vllm.attention.backends.abstract import AttentionBackend, MLAAttentionImpl
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.distributed import (get_decode_context_model_parallel_rank,
get_decode_context_model_parallel_world_size,
get_pcp_group)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import logger
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
@@ -22,6 +19,8 @@ from vllm.v1.kv_cache_interface import MLAAttentionSpec
from vllm_ascend import envs
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.common_cp import (AscendPCPMetadata,
CPChunkedContextMetadata)
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
enable_cp,
maybe_save_kv_layer_to_connector,
@@ -76,44 +75,22 @@ class AscendMLABackend(AttentionBackend):
return AscendMLAImpl
@dataclass
class ChunkedContextMetadata:
# New for MLA (compared to FlashAttention)
# For handling chunked prefill
cu_seq_lens: torch.Tensor
starts: torch.Tensor
seq_tot: list[int]
max_seq_lens: list[int]
workspace: torch.Tensor
chunk_seq_lens: torch.Tensor
chunk_seq_lens_npu: torch.Tensor
@dataclass
class AscendMLAPrefillMetadata:
""" Prefill Specific Metadata for Ascend"""
@dataclass
class ChunkedContextMetadata:
# New for MLA (compared to FlashAttention)
# For handling chunked prefill
cu_seq_lens: torch.Tensor
starts: torch.Tensor
seq_tot: list[int]
max_seq_lens: list[int]
workspace: torch.Tensor
chunk_seq_lens: torch.Tensor
chunk_seq_lens_npu: torch.Tensor
# for mla DCP & PCP
padded_chunk_seq_lens_npu: torch.Tensor = None
padded_local_chunk_seq_lens: Optional[list[list[int]]] = None
local_context_lens_allranks: Optional[list[list[int]]] = None
padded_local_cu_seq_lens: torch.Tensor = None
cu_seq_lens_lst: Optional[list[list[int]]] = None
chunk_size: Optional[int] = None
@dataclass
class AscendPCPMetadata:
q_head_idx: torch.Tensor = None
q_tail_idx: torch.Tensor = None
kv_with_q_head_nomask_idx: torch.Tensor = None
kv_with_q_head_mask_idx: torch.Tensor = None
kv_with_q_tail_nomask_idx: torch.Tensor = None
kv_with_q_tail_mask_idx: torch.Tensor = None
attn_mask_seqlens: torch.Tensor = None
head_attn_nomask_seqlens: torch.Tensor = None
tail_attn_nomask_seqlens: torch.Tensor = None
q_full_idx: torch.Tensor = None
pcp_prefill_mask: torch.Tensor = None
pcp_allgather_restore_idx: Optional[list[int]] = None
attn_mask: torch.Tensor
query_lens: torch.Tensor
seq_lens: list[int]
@@ -123,7 +100,8 @@ class AscendMLAPrefillMetadata:
block_table: torch.Tensor
max_query_len: int
max_seq_lens: int
chunked_context: Optional[ChunkedContextMetadata] = None
chunked_context: Optional[ChunkedContextMetadata
| CPChunkedContextMetadata] = None
sin: torch.Tensor = None
cos: torch.Tensor = None
pcp_metadata: Optional[AscendPCPMetadata] = None
@@ -262,21 +240,21 @@ class AscendMLAMetadataBuilder:
self.cos_cache = None
self.sin_cache = None
self.pcp_size = get_pcp_group().world_size
self.pcp_rank = get_pcp_group(
).rank_in_group if self.pcp_size > 1 else 0
self.dcp_size = get_decode_context_model_parallel_world_size()
self.dcp_rank = get_decode_context_model_parallel_rank(
) if self.dcp_size > 1 else 0
self.cp_local_block_size = vllm_config.parallel_config.cp_kv_cache_interleave_size
self.cp_virtual_block_size = self.cp_local_block_size * self.dcp_size * self.pcp_size
decode_max_num_seqs = getattr(scheduler_config, 'decode_max_num_seqs',
0)
max_num_seqs = max(scheduler_config.max_num_seqs, decode_max_num_seqs)
self.batch_seq_mask_buf = torch.empty(max_num_seqs *
self.decode_threshold,
dtype=torch.uint8,
device=device)
self.chunk_seq_lens: torch.Tensor = None
self.cu_seq_lens_cpu: torch.Tensor = None
self.num_chunks: torch.Tensor = None
self.max_context_chunk = 0
self.num_decodes = 0
self.num_prefills = 0
self.num_decode_tokens = 0
self.num_prefill_tokens = 0
self.context_lens_cpu: torch.Tensor = None
self.num_actual_tokens: Optional[int] = None
self.block_table: torch.Tensor = None
self.slot_mapping: torch.Tensor = None
self.graph_pad_size = 0
self.query_lens: torch.Tensor = None
self.seq_lens: torch.Tensor = None
def reorder_batch(self, input_batch: "NPUInputBatch",
scheduler_output: "SchedulerOutput") -> bool:
@@ -396,6 +374,12 @@ class AscendMLAMetadataBuilder:
actual_seq_lengths_q = actual_seq_lengths_q + interpolated
return actual_seq_lengths_q
def set_num_actual_tokens(
self,
common_attn_metadata: AscendCommonAttentionMetadata,
):
self.num_actual_tokens = common_attn_metadata.num_actual_tokens
def build(
self,
common_prefix_len: int,
@@ -403,41 +387,18 @@ class AscendMLAMetadataBuilder:
model: nn.Module,
) -> AscendMLAMetadata:
num_reqs = common_attn_metadata.num_reqs
num_actual_tokens = common_attn_metadata.num_actual_tokens
query_start_loc = common_attn_metadata.query_start_loc
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
num_actual_tokens_pcp_padded = long_seq_metadata.num_actual_tokens_pcp_padded if long_seq_metadata else None
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
self.num_decodes, self.num_prefills, self.num_decode_tokens, self.num_prefill_tokens = \
split_decodes_and_prefills(common_attn_metadata, decode_threshold=self.decode_threshold)
assert num_decodes + num_prefills == num_reqs
assert num_decode_tokens + num_prefill_tokens == num_actual_tokens
# Note(simon): be careful about the CPU <> GPU memory movement in this
# function. We should avoid GPU -> CPU sync as much as possible because
# it blocks on all previous kernels.
device = self.device
# If graph_pad_size > -1, mean is running in fullgraph mode.
graph_pad_size = common_attn_metadata.graph_pad_size
# NOTE: Maybe this block_table change can be removed when graph_pad_size > 1.
if graph_pad_size > num_reqs and self.speculative_config.disable_padded_drafter_batch:
block_table = (
common_attn_metadata.block_table_tensor[:graph_pad_size])
else:
block_table = (common_attn_metadata.block_table_tensor[:num_reqs])
if num_actual_tokens_pcp_padded is None:
num_actual_tokens_pcp_padded = num_actual_tokens
self.set_num_actual_tokens(common_attn_metadata)
assert self.num_decodes + self.num_prefills == num_reqs
assert self.num_decode_tokens + self.num_prefill_tokens == common_attn_metadata.num_actual_tokens
# NOTE: Currently, MTP-fullgraph is incompatibility pcp
slot_mapping = common_attn_metadata.slot_mapping[:
num_actual_tokens_pcp_padded]
input_positions = common_attn_metadata.positions[:
num_actual_tokens_pcp_padded].long(
)
self.slot_mapping = common_attn_metadata.slot_mapping[:self.
num_actual_tokens]
if self.cos_cache is None:
self.cos_cache = model.model.layers[
@@ -451,210 +412,277 @@ class AscendMLAMetadataBuilder:
self.model_config.dtype) # type: ignore
query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
query_lens = query_seq_lens_cpu[:num_reqs]
seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
num_computed_tokens_cpu = (seq_lens - query_lens)
self.query_lens = query_seq_lens_cpu[:num_reqs]
self.seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
self.set_prefill_block_table(common_attn_metadata)
prefill_metadata = None
chunked_context_metadata = None
if num_prefills > 0:
pcp_metadata = None
reqs_start = num_decodes # prefill_start
tokens_start = num_decode_tokens
max_query_len = query_lens[reqs_start:].max().item()
max_seq_lens = seq_lens[reqs_start:].max().item()
prefill_query_start_loc = query_start_loc[
reqs_start:] - query_start_loc[reqs_start]
context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs]
max_context_len_cpu = context_lens_cpu.max().item()
num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item()
if self.chunked_prefill_enabled and max_context_len_cpu > 0:
max_context_chunk = (self.chunked_prefill_workspace_size //
num_prefills_with_context_cpu)
max_context_chunk = round_down(max_context_chunk,
self.block_size)
assert max_context_chunk > 0
num_chunks = cdiv(max_context_len_cpu, max_context_chunk)
chunk_starts = torch.arange(num_chunks, dtype=torch.int32) \
.unsqueeze(1).expand(-1, num_prefills) * max_context_chunk
chunk_ends = torch.min(context_lens_cpu.unsqueeze(0),
chunk_starts + max_context_chunk)
chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)
cu_seq_lens_cpu = torch.zeros(num_chunks,
num_prefills + 1,
dtype=torch.int32,
pin_memory=True)
torch.cumsum(chunk_seq_lens,
dim=1,
out=cu_seq_lens_cpu[:, 1:],
dtype=torch.int32)
chunked_context_metadata = (
AscendMLAPrefillMetadata.ChunkedContextMetadata(
cu_seq_lens=cu_seq_lens_cpu.pin_memory().to(
device, non_blocking=True),
starts=chunk_starts.pin_memory().to(device,
non_blocking=True),
seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
chunk_seq_lens=chunk_seq_lens,
chunk_seq_lens_npu=chunk_seq_lens.npu(),
workspace=self.chunked_prefill_workspace,
))
prefill_input_positions = input_positions[tokens_start:]
cos = self.cos_cache[
prefill_input_positions].unsqueeze( # type: ignore
1).unsqueeze(2)
sin = self.sin_cache[
prefill_input_positions].unsqueeze( # type: ignore
1).unsqueeze(2)
prefill_metadata = AscendMLAPrefillMetadata(
attn_mask=common_attn_metadata.attn_mask,
query_lens=query_lens[reqs_start:].to(torch.int32),
seq_lens=seq_lens,
context_lens=seq_lens[reqs_start:],
input_positions=prefill_input_positions,
block_table=block_table[reqs_start:, ...],
max_query_len=max_query_len,
max_seq_lens=max_seq_lens,
query_start_loc=prefill_query_start_loc,
chunked_context=chunked_context_metadata,
sin=sin,
cos=cos,
pcp_metadata=pcp_metadata,
)
if self.num_prefills > 0:
prefill_metadata = self.build_prefill_metadata(
common_prefix_len, common_attn_metadata, model)
decode_metadata = None
if num_decodes > 0:
cos, sin = get_cos_and_sin_mla()
# Notice that num_decodes != num_decode_tokens in SpecDecoding Scenario
actual_seq_lengths_q = query_start_loc_cpu[1:num_decodes +
1].tolist()
max_seq_lens = seq_lens[:num_decodes].max().item()
seq_lens = seq_lens[:num_decodes]
input_positions = input_positions[:num_decode_tokens]
block_table = block_table[:num_decodes, ...]
# NOTE: Currently, MTP-fullgraph is incompatibility pcp
# NOTE: Maybe this block_table change can be removed when graph_pad_size > 1.
if graph_pad_size > num_decodes and \
self.speculative_config.disable_padded_drafter_batch:
block_table = block_table[:graph_pad_size, ...]
seq_lens_list = seq_lens.tolist()
cp_seq_len, batch_seq_mask = None, None
if graph_pad_size > num_reqs:
if self.speculative_config.disable_padded_drafter_batch:
num_reqs_pad_size = graph_pad_size - num_reqs
actual_seq_lengths_q = self.pad_actual_seq_len_q_mtp_disable_pad(
num_reqs_pad_size, num_reqs, actual_seq_lengths_q)
seq_lens_list = seq_lens_list + [0] * (graph_pad_size -
num_decodes)
num_block_pad_size = graph_pad_size - block_table.shape[0]
if num_block_pad_size > 0:
block_table_padding = torch.zeros(
(num_block_pad_size, ) + block_table.shape[1:],
dtype=block_table.dtype,
device=block_table.device)
block_table = torch.cat(
[block_table, block_table_padding], dim=0)
else:
num_token_pad_size = graph_pad_size - num_decode_tokens
num_reqs_pad_size = (
graph_pad_size //
common_attn_metadata.decode_token_per_req - num_reqs)
num_block_table_pad_size = (
graph_pad_size //
common_attn_metadata.decode_token_per_req -
num_decodes)
seq_lens_list = seq_lens.tolist() + [0] * num_reqs_pad_size
slot_padding = torch.full((num_token_pad_size, ),
PAD_SLOT_ID,
dtype=slot_mapping.dtype,
device=slot_mapping.device)
slot_mapping = torch.cat([slot_mapping, slot_padding])
block_table_padding = torch.zeros(
(num_block_table_pad_size, ) + block_table.shape[1:],
dtype=block_table.dtype,
device=block_table.device)
block_table = torch.cat([block_table, block_table_padding],
dim=0)
position_padding = torch.zeros(
num_token_pad_size,
dtype=input_positions.dtype,
device=input_positions.device)
input_positions = torch.cat(
[input_positions, position_padding])
actual_seq_lengths_q = self.pad_actual_seq_len_q_mtp_enable_pad(
num_reqs_pad_size, num_reqs, actual_seq_lengths_q,
common_attn_metadata)
# TODO: After the fullgraph supports MTP, the if branch needs to deleted
assert self.cos_cache is not None
assert self.sin_cache is not None
if cos is None and sin is None:
cos = self.cos_cache[
input_positions].unsqueeze( # type: ignore
1).unsqueeze(2)
sin = self.sin_cache[
input_positions].unsqueeze( # type: ignore
1).unsqueeze(2)
decode_metadata = AscendMLADecodeMetadata(
input_positions=input_positions,
block_table=block_table,
seq_lens=seq_lens,
seq_lens_list=seq_lens_list,
max_seq_lens=max_seq_lens,
attn_mask=common_attn_metadata.spec_attn_mask,
actual_seq_lengths_q=actual_seq_lengths_q,
sin=sin,
cos=cos,
cp_seq_len=cp_seq_len,
batch_seq_mask=batch_seq_mask)
else:
cos[:num_decode_tokens,
...] = self.cos_cache[input_positions].unsqueeze(
1).unsqueeze(2)
sin[:num_decode_tokens,
...] = self.sin_cache[input_positions].unsqueeze(
1).unsqueeze(2)
decode_metadata = AscendMLADecodeMetadata(
input_positions=input_positions,
block_table=block_table,
seq_lens=seq_lens,
seq_lens_list=seq_lens_list,
max_seq_lens=max_seq_lens,
attn_mask=common_attn_metadata.spec_attn_mask,
actual_seq_lengths_q=actual_seq_lengths_q,
sin=sin[:num_decode_tokens, ...],
cos=cos[:num_decode_tokens, ...],
cp_seq_len=cp_seq_len,
batch_seq_mask=batch_seq_mask)
if self.num_decodes > 0:
decode_metadata = self.build_decode_metadata(
common_prefix_len, common_attn_metadata, model)
return self.metadata_cls( # type: ignore
num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded,
num_actual_tokens_pcp_padded=self.num_actual_tokens,
num_input_tokens=common_attn_metadata.num_input_tokens,
num_actual_tokens=num_actual_tokens,
query_lens=query_lens.tolist(),
slot_mapping=slot_mapping,
num_actual_tokens=self.num_actual_tokens,
query_lens=self.query_lens.tolist(),
slot_mapping=self.slot_mapping,
head_dim=self.model_config.get_head_size(),
num_decodes=num_decodes,
num_decode_tokens=num_decode_tokens,
num_prefills=num_prefills,
num_decodes=self.num_decodes,
num_decode_tokens=self.num_decode_tokens,
num_prefills=self.num_prefills,
attn_mask=common_attn_metadata.attn_mask,
attn_state=common_attn_metadata.attn_state,
prefill=prefill_metadata,
decode=decode_metadata,
query_start_loc=query_start_loc,
block_tables=block_table,
seq_lens=seq_lens,
block_tables=self.block_table,
seq_lens=self.seq_lens,
)
def build_chunked_metadata(
self,
common_prefix_len: int,
common_attn_metadata: AscendCommonAttentionMetadata,
model: nn.Module,
):
if not self.chunked_prefill_enabled:
return None
num_reqs = common_attn_metadata.num_reqs
num_computed_tokens_cpu = (self.seq_lens - self.query_lens)
reqs_start = self.num_decodes # prefill_start
self.context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs]
max_context_len_cpu = self.context_lens_cpu.max().item()
if not max_context_len_cpu > 0:
return None
num_prefills_with_context_cpu = (self.context_lens_cpu
> 0).sum().item()
self.max_context_chunk = (self.chunked_prefill_workspace_size //
num_prefills_with_context_cpu)
self.max_context_chunk = round_down(self.max_context_chunk,
self.block_size)
assert self.max_context_chunk > 0
self.num_chunks = cdiv(max_context_len_cpu, self.max_context_chunk)
chunk_starts = torch.arange(self.num_chunks, dtype=torch.int32) \
.unsqueeze(1).expand(-1, self.num_prefills) * self.max_context_chunk
chunk_ends = torch.min(self.context_lens_cpu.unsqueeze(0),
chunk_starts + self.max_context_chunk)
self.chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)
self.cu_seq_lens_cpu = torch.zeros(self.num_chunks,
self.num_prefills + 1,
dtype=torch.int32,
pin_memory=True)
torch.cumsum(self.chunk_seq_lens,
dim=1,
out=self.cu_seq_lens_cpu[:, 1:],
dtype=torch.int32)
return ChunkedContextMetadata(
cu_seq_lens=self.cu_seq_lens_cpu.pin_memory().to(
self.device, non_blocking=True),
starts=chunk_starts.pin_memory().to(self.device,
non_blocking=True),
seq_tot=self.chunk_seq_lens.sum(dim=1).tolist(),
max_seq_lens=self.chunk_seq_lens.max(dim=1).values.tolist(),
chunk_seq_lens=self.chunk_seq_lens,
chunk_seq_lens_npu=self.chunk_seq_lens.npu(),
workspace=self.chunked_prefill_workspace,
)
def set_prefill_block_table(
self,
common_attn_metadata: AscendCommonAttentionMetadata,
):
# If graph_pad_size > -1, mean is running in fullgraph mode.
self.graph_pad_size = common_attn_metadata.graph_pad_size
# NOTE: Maybe this block_table change can be removed when graph_pad_size > 1.
if self.graph_pad_size > common_attn_metadata.num_reqs and self.speculative_config.disable_padded_drafter_batch:
self.block_table = (
common_attn_metadata.block_table_tensor[:self.graph_pad_size])
else:
self.block_table = (
common_attn_metadata.block_table_tensor[:common_attn_metadata.
num_reqs])
def set_decode_block_table(
self, common_attn_metadata: AscendCommonAttentionMetadata):
self.block_table = self.block_table[:self.num_decodes, ...]
def build_prefill_metadata(
self,
common_prefix_len: int,
common_attn_metadata: AscendCommonAttentionMetadata,
model: nn.Module,
) -> AscendMLAPrefillMetadata:
query_start_loc = common_attn_metadata.query_start_loc
# NOTE: Currently, MTP-fullgraph is incompatibility pcp
input_positions = common_attn_metadata.positions[:self.
num_actual_tokens].long(
)
chunked_context_metadata = self.build_chunked_metadata(
common_prefix_len, common_attn_metadata, model)
reqs_start = self.num_decodes # prefill_start
tokens_start = self.num_decode_tokens
max_query_len = self.query_lens[reqs_start:].max().item()
max_seq_lens = self.seq_lens[reqs_start:].max().item()
prefill_query_start_loc = query_start_loc[
reqs_start:] - query_start_loc[reqs_start]
prefill_input_positions = input_positions[tokens_start:]
cos = self.cos_cache[
prefill_input_positions].unsqueeze( # type: ignore
1).unsqueeze(2)
sin = self.sin_cache[
prefill_input_positions].unsqueeze( # type: ignore
1).unsqueeze(2)
return AscendMLAPrefillMetadata(
attn_mask=common_attn_metadata.attn_mask,
query_lens=self.query_lens[reqs_start:].to(torch.int32),
seq_lens=self.seq_lens,
context_lens=self.seq_lens[reqs_start:],
input_positions=prefill_input_positions,
block_table=self.block_table[reqs_start:, ...],
max_query_len=max_query_len,
max_seq_lens=max_seq_lens,
query_start_loc=prefill_query_start_loc,
chunked_context=chunked_context_metadata,
sin=sin,
cos=cos,
)
def build_decode_metadata(
self,
common_prefix_len: int,
common_attn_metadata: AscendCommonAttentionMetadata,
model: nn.Module,
) -> AscendMLADecodeMetadata:
num_reqs = common_attn_metadata.num_reqs
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
input_positions = common_attn_metadata.positions[:self.
num_actual_tokens].long(
)
cos, sin = get_cos_and_sin_mla()
# Notice that num_decodes != num_decode_tokens in SpecDecoding Scenario
actual_seq_lengths_q = query_start_loc_cpu[1:self.num_decodes +
1].tolist()
max_seq_lens = self.seq_lens[:self.num_decodes].max().item()
self.seq_lens = self.seq_lens[:self.num_decodes]
input_positions = input_positions[:self.num_decode_tokens]
self.set_decode_block_table(common_attn_metadata)
# NOTE: Currently, MTP-fullgraph is incompatibility pcp
# NOTE: Maybe this block_table change can be removed when graph_pad_size > 1.
if self.graph_pad_size > self.num_decodes and \
self.speculative_config.disable_padded_drafter_batch:
self.block_table = self.block_table[:self.graph_pad_size, ...]
seq_lens_list = self.seq_lens.tolist()
cp_seq_len, batch_seq_mask = None, None
if self.graph_pad_size > num_reqs:
if self.speculative_config.disable_padded_drafter_batch:
num_reqs_pad_size = self.graph_pad_size - num_reqs
actual_seq_lengths_q = self.pad_actual_seq_len_q_mtp_disable_pad(
num_reqs_pad_size, num_reqs, actual_seq_lengths_q)
seq_lens_list = seq_lens_list + [0] * (self.graph_pad_size -
self.num_decodes)
num_block_pad_size = self.graph_pad_size - self.block_table.shape[
0]
if num_block_pad_size > 0:
block_table_padding = torch.zeros(
(num_block_pad_size, ) + self.block_table.shape[1:],
dtype=self.block_table.dtype,
device=self.block_table.device)
self.block_table = torch.cat(
[self.block_table, block_table_padding], dim=0)
else:
num_token_pad_size = self.graph_pad_size - self.num_decode_tokens
num_reqs_pad_size = (
self.graph_pad_size //
common_attn_metadata.decode_token_per_req - num_reqs)
num_block_table_pad_size = (
self.graph_pad_size //
common_attn_metadata.decode_token_per_req -
self.num_decodes)
seq_lens_list = self.seq_lens.tolist() + [0
] * num_reqs_pad_size
slot_padding = torch.full((num_token_pad_size, ),
PAD_SLOT_ID,
dtype=self.slot_mapping.dtype,
device=self.slot_mapping.device)
self.slot_mapping = torch.cat(
[self.slot_mapping, slot_padding])
block_table_padding = torch.zeros(
(num_block_table_pad_size, ) + self.block_table.shape[1:],
dtype=self.block_table.dtype,
device=self.block_table.device)
self.block_table = torch.cat(
[self.block_table, block_table_padding], dim=0)
position_padding = torch.zeros(num_token_pad_size,
dtype=input_positions.dtype,
device=input_positions.device)
input_positions = torch.cat(
[input_positions, position_padding])
actual_seq_lengths_q = self.pad_actual_seq_len_q_mtp_enable_pad(
num_reqs_pad_size, num_reqs, actual_seq_lengths_q,
common_attn_metadata)
# TODO: After the fullgraph supports MTP, the if branch needs to deleted
assert self.cos_cache is not None
assert self.sin_cache is not None
if cos is None and sin is None:
cos = self.cos_cache[input_positions].unsqueeze( # type: ignore
1).unsqueeze(2)
sin = self.sin_cache[input_positions].unsqueeze( # type: ignore
1).unsqueeze(2)
decode_metadata = AscendMLADecodeMetadata(
input_positions=input_positions,
block_table=self.block_table,
seq_lens=self.seq_lens,
seq_lens_list=seq_lens_list,
max_seq_lens=max_seq_lens,
attn_mask=common_attn_metadata.spec_attn_mask,
actual_seq_lengths_q=actual_seq_lengths_q,
sin=sin,
cos=cos,
cp_seq_len=cp_seq_len,
batch_seq_mask=batch_seq_mask)
else:
cos[:self.num_decode_tokens,
...] = self.cos_cache[input_positions].unsqueeze(1).unsqueeze(
2)
sin[:self.num_decode_tokens,
...] = self.sin_cache[input_positions].unsqueeze(1).unsqueeze(
2)
decode_metadata = AscendMLADecodeMetadata(
input_positions=input_positions,
block_table=self.block_table,
seq_lens=self.seq_lens,
seq_lens_list=seq_lens_list,
max_seq_lens=max_seq_lens,
attn_mask=common_attn_metadata.spec_attn_mask,
actual_seq_lengths_q=actual_seq_lengths_q,
sin=sin[:self.num_decode_tokens, ...],
cos=cos[:self.num_decode_tokens, ...],
cp_seq_len=cp_seq_len,
batch_seq_mask=batch_seq_mask)
return decode_metadata
def build_for_graph_capture(
self,
common_attn_metadata: AscendCommonAttentionMetadata,

View File

@@ -87,7 +87,7 @@ class AscendPrefillContextParallelMetadata:
cp_kv_recover_idx_for_chunk: torch.Tensor = None
num_actual_tokens_pcp_padded: Optional[int] = None
num_actual_tokens_pcp_padded: int = 0
num_computed_tokens_of_pcp_dcp: Optional[list[list[list[int]]]] = None