[Refactor]5/N Extract common code of mla_v1.py & extract mla_cp (#5097)
RFC: https://github.com/vllm-project/vllm-ascend/issues/4629 Reason: The functions related to Cp differ significantly from those of normal MLA-Attention, but the coupling is quite severe. Steps: 1)Extract common code AscendMLAMetadataBuilder.build to 4 functions: build_prefill_metadata, build_decode_metadata,build_cp_metadata, build_chunked_metadata todo: 1)refactor function _compute_prefill_context; 2)refactor function _mla_preprocess,_mla_decode_preprocess 3)Extract public data and processing functions from the attention_cp.py and mla_cp.py files to the common_cp file. vLLM version: 0.13.0rc3 vLLM main:ad32e3e19c- vLLM version: 0.13.0rc3 - vLLM main:ad32e3e19c--------- Signed-off-by: wujinyuan1 <wjy9595@qq.com> Signed-off-by: wujinyuan1 <wujinyuan1@huawei.com> Co-authored-by: wujinyuan1 <wjy9595@qq.com> Co-authored-by: weijinqian0 <1184188277@qq.com>
This commit is contained in:
@@ -6,8 +6,9 @@ from vllm.distributed.parallel_state import GroupCoordinator
|
|||||||
from tests.ut.base import TestBase
|
from tests.ut.base import TestBase
|
||||||
from vllm_ascend.ascend_config import init_ascend_config
|
from vllm_ascend.ascend_config import init_ascend_config
|
||||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||||
|
from vllm_ascend.attention.common_cp import CPChunkedContextMetadata
|
||||||
from vllm_ascend.attention.mla_cp import AscendMlaCPImpl
|
from vllm_ascend.attention.mla_cp import AscendMlaCPImpl
|
||||||
from vllm_ascend.attention.mla_v1 import AscendMLAPrefillMetadata
|
from vllm_ascend.attention.mla_v1 import ChunkedContextMetadata
|
||||||
|
|
||||||
|
|
||||||
def get_pcp_split_info(pcp_rank, pcp_size, seq_lens):
|
def get_pcp_split_info(pcp_rank, pcp_size, seq_lens):
|
||||||
@@ -127,7 +128,7 @@ def get_chunk_metadata(pcp_size, dcp_size, num_prefills, num_decodes,
|
|||||||
out=padded_local_cu_chunk_seq_lens_cpu[:, 1:],
|
out=padded_local_cu_chunk_seq_lens_cpu[:, 1:],
|
||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
)
|
)
|
||||||
chunked_context_metadata = AscendMLAPrefillMetadata.ChunkedContextMetadata(
|
chunked_context_metadata = CPChunkedContextMetadata(
|
||||||
cu_seq_lens=cu_seq_lens_cpu.to(non_blocking=True),
|
cu_seq_lens=cu_seq_lens_cpu.to(non_blocking=True),
|
||||||
starts=local_chunk_starts.to(non_blocking=True),
|
starts=local_chunk_starts.to(non_blocking=True),
|
||||||
seq_tot=padded_local_chunk_seq_lens.sum(dim=1).tolist(),
|
seq_tot=padded_local_chunk_seq_lens.sum(dim=1).tolist(),
|
||||||
@@ -144,8 +145,7 @@ def get_chunk_metadata(pcp_size, dcp_size, num_prefills, num_decodes,
|
|||||||
chunk_size=padded_local_max_context_chunk_across_ranks,
|
chunk_size=padded_local_max_context_chunk_across_ranks,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
chunked_context_metadata = (
|
chunked_context_metadata = (ChunkedContextMetadata(
|
||||||
AscendMLAPrefillMetadata.ChunkedContextMetadata(
|
|
||||||
cu_seq_lens=cu_seq_lens_cpu.to(non_blocking=True),
|
cu_seq_lens=cu_seq_lens_cpu.to(non_blocking=True),
|
||||||
starts=chunk_starts.to(non_blocking=True),
|
starts=chunk_starts.to(non_blocking=True),
|
||||||
seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
|
seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
|
||||||
|
|||||||
@@ -14,7 +14,8 @@ from vllm_ascend.attention.mla_v1 import (AscendMLABackend,
|
|||||||
AscendMLADecodeMetadata,
|
AscendMLADecodeMetadata,
|
||||||
AscendMLAImpl, AscendMLAMetadata,
|
AscendMLAImpl, AscendMLAMetadata,
|
||||||
AscendMLAMetadataBuilder,
|
AscendMLAMetadataBuilder,
|
||||||
AscendMLAPrefillMetadata)
|
AscendMLAPrefillMetadata,
|
||||||
|
ChunkedContextMetadata)
|
||||||
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
|
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
|
||||||
|
|
||||||
|
|
||||||
@@ -76,27 +77,15 @@ class TestAscendMLAPrefillMetadata(TestBase):
|
|||||||
max_seq_lens = [2, 2]
|
max_seq_lens = [2, 2]
|
||||||
workspace = torch.randn(2, 4)
|
workspace = torch.randn(2, 4)
|
||||||
chunk_seq_lens = torch.tensor([2, 2])
|
chunk_seq_lens = torch.tensor([2, 2])
|
||||||
padded_chunk_seq_lens_npu = torch.tensor([2, 2])
|
|
||||||
padded_local_chunk_seq_lens = [[2], [2]]
|
|
||||||
local_context_lens_allranks = [[1, 1], [1, 1]]
|
|
||||||
padded_local_cu_seq_lens = torch.tensor([0, 2, 4])
|
|
||||||
cu_seq_lens_lst = [[0, 2], [2, 4]]
|
|
||||||
chunk_size = 2
|
|
||||||
|
|
||||||
chunked_context = AscendMLAPrefillMetadata.ChunkedContextMetadata(
|
chunked_context = ChunkedContextMetadata(
|
||||||
cu_seq_lens=cu_seq_lens,
|
cu_seq_lens=cu_seq_lens,
|
||||||
starts=starts,
|
starts=starts,
|
||||||
seq_tot=seq_tot,
|
seq_tot=seq_tot,
|
||||||
max_seq_lens=max_seq_lens,
|
max_seq_lens=max_seq_lens,
|
||||||
workspace=workspace,
|
workspace=workspace,
|
||||||
chunk_seq_lens=chunk_seq_lens,
|
chunk_seq_lens=chunk_seq_lens,
|
||||||
chunk_seq_lens_npu=chunk_seq_lens,
|
chunk_seq_lens_npu=chunk_seq_lens)
|
||||||
padded_chunk_seq_lens_npu=padded_chunk_seq_lens_npu,
|
|
||||||
padded_local_chunk_seq_lens=padded_local_chunk_seq_lens,
|
|
||||||
local_context_lens_allranks=local_context_lens_allranks,
|
|
||||||
padded_local_cu_seq_lens=padded_local_cu_seq_lens,
|
|
||||||
cu_seq_lens_lst=cu_seq_lens_lst,
|
|
||||||
chunk_size=chunk_size)
|
|
||||||
|
|
||||||
metadata = AscendMLAPrefillMetadata(
|
metadata = AscendMLAPrefillMetadata(
|
||||||
attn_mask=torch.tensor([[1, 0], [1, 1]], dtype=torch.bool),
|
attn_mask=torch.tensor([[1, 0], [1, 1]], dtype=torch.bool),
|
||||||
@@ -119,17 +108,6 @@ class TestAscendMLAPrefillMetadata(TestBase):
|
|||||||
self.assertIs(metadata.chunked_context.chunk_seq_lens, chunk_seq_lens)
|
self.assertIs(metadata.chunked_context.chunk_seq_lens, chunk_seq_lens)
|
||||||
self.assertIs(metadata.chunked_context.chunk_seq_lens_npu,
|
self.assertIs(metadata.chunked_context.chunk_seq_lens_npu,
|
||||||
chunk_seq_lens)
|
chunk_seq_lens)
|
||||||
self.assertIs(metadata.chunked_context.padded_chunk_seq_lens_npu,
|
|
||||||
padded_chunk_seq_lens_npu)
|
|
||||||
self.assertEqual(metadata.chunked_context.padded_local_chunk_seq_lens,
|
|
||||||
padded_local_chunk_seq_lens)
|
|
||||||
self.assertEqual(metadata.chunked_context.local_context_lens_allranks,
|
|
||||||
local_context_lens_allranks)
|
|
||||||
self.assertIs(metadata.chunked_context.padded_local_cu_seq_lens,
|
|
||||||
padded_local_cu_seq_lens)
|
|
||||||
self.assertEqual(metadata.chunked_context.cu_seq_lens_lst,
|
|
||||||
cu_seq_lens_lst)
|
|
||||||
self.assertEqual(metadata.chunked_context.chunk_size, chunk_size)
|
|
||||||
|
|
||||||
|
|
||||||
class TestAscendMLADecodeMetadata(TestBase):
|
class TestAscendMLADecodeMetadata(TestBase):
|
||||||
@@ -218,11 +196,9 @@ class TestAscendMLAMetadataBuilder(TestBase):
|
|||||||
@patch('vllm.distributed.parallel_state.get_dcp_group')
|
@patch('vllm.distributed.parallel_state.get_dcp_group')
|
||||||
@patch('vllm.distributed.parallel_state._DCP',
|
@patch('vllm.distributed.parallel_state._DCP',
|
||||||
new_callable=lambda: MagicMock(spec=GroupCoordinator))
|
new_callable=lambda: MagicMock(spec=GroupCoordinator))
|
||||||
@patch("vllm.distributed.get_decode_context_model_parallel_world_size",
|
def test_ascend_mla_metadata_builder_default(self, mock_dcp,
|
||||||
return_value=1)
|
mock_get_dcp_group, mock_pcp,
|
||||||
def test_ascend_mla_metadata_builder_default(self, mock_get_dcp_size,
|
mock_get_pcp_group):
|
||||||
mock_dcp, mock_get_dcp_group,
|
|
||||||
mock_pcp, mock_get_pcp_group):
|
|
||||||
mock_vllm_config = MagicMock()
|
mock_vllm_config = MagicMock()
|
||||||
mock_vllm_config.model_config.max_model_len = 1024
|
mock_vllm_config.model_config.max_model_len = 1024
|
||||||
mock_vllm_config.model_config.get_head_size.return_value = 64
|
mock_vllm_config.model_config.get_head_size.return_value = 64
|
||||||
@@ -262,8 +238,6 @@ class TestAscendMLAMetadataBuilder(TestBase):
|
|||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
builder.chunked_prefill_enabled,
|
builder.chunked_prefill_enabled,
|
||||||
mock_vllm_config.scheduler_config.enable_chunked_prefill)
|
mock_vllm_config.scheduler_config.enable_chunked_prefill)
|
||||||
self.assertEqual(builder.dcp_size, mock_dcp.world_size)
|
|
||||||
self.assertEqual(builder.pcp_size, mock_pcp.world_size)
|
|
||||||
|
|
||||||
@patch('vllm.distributed.parallel_state.get_pcp_group')
|
@patch('vllm.distributed.parallel_state.get_pcp_group')
|
||||||
@patch('vllm.distributed.parallel_state._PCP',
|
@patch('vllm.distributed.parallel_state._PCP',
|
||||||
@@ -271,10 +245,7 @@ class TestAscendMLAMetadataBuilder(TestBase):
|
|||||||
@patch('vllm.distributed.parallel_state.get_dcp_group')
|
@patch('vllm.distributed.parallel_state.get_dcp_group')
|
||||||
@patch('vllm.distributed.parallel_state._DCP',
|
@patch('vllm.distributed.parallel_state._DCP',
|
||||||
new_callable=lambda: MagicMock(spec=GroupCoordinator))
|
new_callable=lambda: MagicMock(spec=GroupCoordinator))
|
||||||
@patch("vllm.distributed.get_decode_context_model_parallel_world_size",
|
def test_ascend_mla_metadata_builder_spec_decode(self, mock_dcp,
|
||||||
return_value=1)
|
|
||||||
def test_ascend_mla_metadata_builder_spec_decode(self, mock_get_dcp_size,
|
|
||||||
mock_dcp,
|
|
||||||
mock_get_dcp_group,
|
mock_get_dcp_group,
|
||||||
mock_pcp,
|
mock_pcp,
|
||||||
mock_get_pcp_group):
|
mock_get_pcp_group):
|
||||||
@@ -324,11 +295,8 @@ class TestAscendMLAMetadataBuilder(TestBase):
|
|||||||
@patch('vllm.distributed.parallel_state.get_dcp_group')
|
@patch('vllm.distributed.parallel_state.get_dcp_group')
|
||||||
@patch('vllm.distributed.parallel_state._DCP',
|
@patch('vllm.distributed.parallel_state._DCP',
|
||||||
new_callable=lambda: MagicMock(spec=GroupCoordinator))
|
new_callable=lambda: MagicMock(spec=GroupCoordinator))
|
||||||
@patch("vllm.distributed.get_decode_context_model_parallel_world_size",
|
|
||||||
return_value=1)
|
|
||||||
def test_ascend_mla_metadata_builder_build_full_graph(
|
def test_ascend_mla_metadata_builder_build_full_graph(
|
||||||
self, mock_get_dcp_size, mock_dcp, mock_get_dcp_group, mock_pcp,
|
self, mock_dcp, mock_get_dcp_group, mock_pcp, mock_get_pcp_group):
|
||||||
mock_get_pcp_group):
|
|
||||||
mock_vllm_config = MagicMock()
|
mock_vllm_config = MagicMock()
|
||||||
mock_vllm_config.model_config.max_model_len = 1024
|
mock_vllm_config.model_config.max_model_len = 1024
|
||||||
mock_vllm_config.model_config.get_head_size.return_value = 64
|
mock_vllm_config.model_config.get_head_size.return_value = 64
|
||||||
@@ -387,10 +355,8 @@ class TestAscendMLAMetadataBuilder(TestBase):
|
|||||||
@patch('vllm.distributed.parallel_state.get_dcp_group')
|
@patch('vllm.distributed.parallel_state.get_dcp_group')
|
||||||
@patch('vllm.distributed.parallel_state._DCP',
|
@patch('vllm.distributed.parallel_state._DCP',
|
||||||
new_callable=lambda: MagicMock(spec=GroupCoordinator))
|
new_callable=lambda: MagicMock(spec=GroupCoordinator))
|
||||||
@patch("vllm.distributed.get_decode_context_model_parallel_world_size",
|
def test_reorder_batch(self, mock_dcp, mock_get_dcp_group, mock_pcp,
|
||||||
return_value=1)
|
mock_get_pcp_group):
|
||||||
def test_reorder_batch(self, mock_get_dcp_size, mock_dcp,
|
|
||||||
mock_get_dcp_group, mock_pcp, mock_get_pcp_group):
|
|
||||||
ascend_config = MagicMock()
|
ascend_config = MagicMock()
|
||||||
|
|
||||||
mock_vllm_config = MagicMock()
|
mock_vllm_config = MagicMock()
|
||||||
@@ -448,10 +414,7 @@ class TestAscendMLAMetadataBuilder(TestBase):
|
|||||||
@patch('vllm.distributed.parallel_state.get_dcp_group')
|
@patch('vllm.distributed.parallel_state.get_dcp_group')
|
||||||
@patch('vllm.distributed.parallel_state._DCP',
|
@patch('vllm.distributed.parallel_state._DCP',
|
||||||
new_callable=lambda: MagicMock(spec=GroupCoordinator))
|
new_callable=lambda: MagicMock(spec=GroupCoordinator))
|
||||||
@patch("vllm.distributed.get_decode_context_model_parallel_world_size",
|
def test_pad_actual_seq_lens_q_mtp_disable_pad(self, mock_dcp,
|
||||||
return_value=1)
|
|
||||||
def test_pad_actual_seq_lens_q_mtp_disable_pad(self, mock_get_dcp_size,
|
|
||||||
mock_dcp,
|
|
||||||
mock_get_dcp_group,
|
mock_get_dcp_group,
|
||||||
mock_pcp,
|
mock_pcp,
|
||||||
mock_get_pcp_group):
|
mock_get_pcp_group):
|
||||||
@@ -496,11 +459,8 @@ class TestAscendMLAMetadataBuilder(TestBase):
|
|||||||
@patch('vllm.distributed.parallel_state.get_dcp_group')
|
@patch('vllm.distributed.parallel_state.get_dcp_group')
|
||||||
@patch('vllm.distributed.parallel_state._DCP',
|
@patch('vllm.distributed.parallel_state._DCP',
|
||||||
new_callable=lambda: MagicMock(spec=GroupCoordinator))
|
new_callable=lambda: MagicMock(spec=GroupCoordinator))
|
||||||
@patch("vllm.distributed.get_decode_context_model_parallel_world_size",
|
def test_pad_actual_seq_lens_q_mtp_enable_pad(self, mock_dcp,
|
||||||
return_value=1)
|
mock_get_dcp_group, mock_pcp,
|
||||||
def test_pad_actual_seq_lens_q_mtp_enable_pad(self, mock_get_dcp_size,
|
|
||||||
mock_dcp, mock_get_dcp_group,
|
|
||||||
mock_pcp,
|
|
||||||
mock_get_pcp_group):
|
mock_get_pcp_group):
|
||||||
mock_vllm_config = MagicMock()
|
mock_vllm_config = MagicMock()
|
||||||
mock_vllm_config.model_config.max_model_len = 1024
|
mock_vllm_config.model_config.max_model_len = 1024
|
||||||
@@ -566,17 +526,14 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
|
|||||||
self.kv_cache_spec.head_size = 128
|
self.kv_cache_spec.head_size = 128
|
||||||
self.kv_cache_spec.num_heads = 32
|
self.kv_cache_spec.num_heads = 32
|
||||||
|
|
||||||
@patch("vllm_ascend.attention.mla_v1.get_pcp_group")
|
@patch('vllm.distributed.parallel_state.get_pcp_group')
|
||||||
@patch(
|
@patch("vllm.distributed.get_decode_context_model_parallel_world_size",
|
||||||
"vllm_ascend.attention.mla_v1.get_decode_context_model_parallel_world_size"
|
return_value=1)
|
||||||
)
|
|
||||||
@patch("vllm_ascend.attention.mla_v1.get_ascend_config")
|
|
||||||
@patch("vllm_ascend.attention.mla_v1.torch.zeros", wraps=torch.zeros)
|
@patch("vllm_ascend.attention.mla_v1.torch.zeros", wraps=torch.zeros)
|
||||||
@patch("torch.Tensor.npu", new=lambda self: self)
|
@patch("torch.Tensor.npu", new=lambda self: self)
|
||||||
@patch("torch.npu.is_available")
|
@patch("torch.npu.is_available")
|
||||||
def test_build_prefix_no_cache_metadata(self, mock_npu_available,
|
def test_build_prefix_no_cache_metadata(self, mock_npu_available,
|
||||||
mock_zeros, mock_get_ascend_config,
|
mock_zeros, mock_dcp_world_size,
|
||||||
mock_dcp_world_size,
|
|
||||||
mock_get_pcp_group):
|
mock_get_pcp_group):
|
||||||
mock_npu_available.return_value = False
|
mock_npu_available.return_value = False
|
||||||
mock_dcp_world_size.return_value = 1
|
mock_dcp_world_size.return_value = 1
|
||||||
@@ -633,17 +590,14 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
|
|||||||
torch.all(metadata.slot_mapping == base_inputs["slot_mapping"]))
|
torch.all(metadata.slot_mapping == base_inputs["slot_mapping"]))
|
||||||
self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size)
|
self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size)
|
||||||
|
|
||||||
@patch("vllm_ascend.attention.mla_v1.get_pcp_group")
|
@patch('vllm.distributed.parallel_state.get_pcp_group')
|
||||||
@patch(
|
@patch("vllm.distributed.get_decode_context_model_parallel_world_size",
|
||||||
"vllm_ascend.attention.mla_v1.get_decode_context_model_parallel_world_size"
|
return_value=1)
|
||||||
)
|
|
||||||
@patch("vllm_ascend.attention.mla_v1.get_ascend_config")
|
|
||||||
@patch("vllm_ascend.attention.mla_v1.torch.zeros", wraps=torch.zeros)
|
@patch("vllm_ascend.attention.mla_v1.torch.zeros", wraps=torch.zeros)
|
||||||
@patch("torch.Tensor.npu", new=lambda self: self)
|
@patch("torch.Tensor.npu", new=lambda self: self)
|
||||||
@patch("torch.npu.is_available")
|
@patch("torch.npu.is_available")
|
||||||
def test_build_chunked_prefix_metadata(self, mock_npu_available,
|
def test_build_chunked_prefix_metadata(self, mock_npu_available,
|
||||||
mock_zeros, mock_get_ascend_config,
|
mock_zeros, mock_dcp_world_size,
|
||||||
mock_dcp_world_size,
|
|
||||||
mock_get_pcp_group):
|
mock_get_pcp_group):
|
||||||
mock_npu_available.return_value = False
|
mock_npu_available.return_value = False
|
||||||
mock_dcp_world_size.return_value = 1
|
mock_dcp_world_size.return_value = 1
|
||||||
@@ -701,13 +655,10 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
|
|||||||
torch.all(metadata.slot_mapping == base_inputs["slot_mapping"]))
|
torch.all(metadata.slot_mapping == base_inputs["slot_mapping"]))
|
||||||
self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size)
|
self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size)
|
||||||
|
|
||||||
@patch("vllm_ascend.attention.mla_v1.get_pcp_group")
|
@patch('vllm.distributed.parallel_state.get_pcp_group')
|
||||||
@patch(
|
@patch("vllm.distributed.get_decode_context_model_parallel_world_size",
|
||||||
"vllm_ascend.attention.mla_v1.get_decode_context_model_parallel_world_size"
|
return_value=1)
|
||||||
)
|
def test_build_decode_only_metadata(self, mock_dcp_world_size,
|
||||||
@patch("vllm_ascend.attention.mla_v1.get_ascend_config")
|
|
||||||
def test_build_decode_only_metadata(self, mock_get_ascend_config,
|
|
||||||
mock_dcp_world_size,
|
|
||||||
mock_get_pcp_group):
|
mock_get_pcp_group):
|
||||||
mock_dcp_world_size.return_value = 1
|
mock_dcp_world_size.return_value = 1
|
||||||
torch.Tensor.pin_memory = lambda x: x # noqa
|
torch.Tensor.pin_memory = lambda x: x # noqa
|
||||||
@@ -757,13 +708,10 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
|
|||||||
torch.all(metadata.slot_mapping == base_inputs["slot_mapping"]))
|
torch.all(metadata.slot_mapping == base_inputs["slot_mapping"]))
|
||||||
self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size)
|
self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size)
|
||||||
|
|
||||||
@patch("vllm_ascend.attention.mla_v1.get_pcp_group")
|
@patch('vllm.distributed.parallel_state.get_pcp_group')
|
||||||
@patch(
|
@patch("vllm.distributed.get_decode_context_model_parallel_world_size",
|
||||||
"vllm_ascend.attention.mla_v1.get_decode_context_model_parallel_world_size"
|
return_value=1)
|
||||||
)
|
def test_build_for_graph_capture_decode_only(self, mock_dcp_world_size,
|
||||||
@patch("vllm_ascend.attention.mla_v1.get_ascend_config")
|
|
||||||
def test_build_for_graph_capture_decode_only(self, mock_get_ascend_config,
|
|
||||||
mock_dcp_world_size,
|
|
||||||
mock_get_pcp_group):
|
mock_get_pcp_group):
|
||||||
mock_dcp_world_size.return_value = 1
|
mock_dcp_world_size.return_value = 1
|
||||||
torch.Tensor.pin_memory = lambda x: x # noqa
|
torch.Tensor.pin_memory = lambda x: x # noqa
|
||||||
@@ -814,13 +762,10 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
|
|||||||
torch.all(metadata.slot_mapping == base_inputs["slot_mapping"]))
|
torch.all(metadata.slot_mapping == base_inputs["slot_mapping"]))
|
||||||
self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size)
|
self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size)
|
||||||
|
|
||||||
@patch("vllm_ascend.attention.mla_v1.get_pcp_group")
|
@patch('vllm.distributed.parallel_state.get_pcp_group')
|
||||||
@patch(
|
@patch("vllm.distributed.get_decode_context_model_parallel_world_size",
|
||||||
"vllm_ascend.attention.mla_v1.get_decode_context_model_parallel_world_size"
|
return_value=1)
|
||||||
)
|
def test_build_for_graph_capture_prefill(self, mock_dcp_world_size,
|
||||||
@patch("vllm_ascend.attention.mla_v1.get_ascend_config")
|
|
||||||
def test_build_for_graph_capture_prefill(self, mock_get_ascend_config,
|
|
||||||
mock_dcp_world_size,
|
|
||||||
mock_get_pcp_group):
|
mock_get_pcp_group):
|
||||||
mock_dcp_world_size.return_value = 1
|
mock_dcp_world_size.return_value = 1
|
||||||
torch.Tensor.pin_memory = lambda x: x # noqa
|
torch.Tensor.pin_memory = lambda x: x # noqa
|
||||||
@@ -868,16 +813,10 @@ class TestAscendMLAImpl(TestBase):
|
|||||||
new_callable=lambda: MagicMock(spec=GroupCoordinator))
|
new_callable=lambda: MagicMock(spec=GroupCoordinator))
|
||||||
@patch('vllm.distributed.parallel_state._DCP',
|
@patch('vllm.distributed.parallel_state._DCP',
|
||||||
new_callable=lambda: MagicMock(spec=GroupCoordinator))
|
new_callable=lambda: MagicMock(spec=GroupCoordinator))
|
||||||
@patch("vllm.distributed.get_decode_context_model_parallel_world_size",
|
|
||||||
return_value=1)
|
|
||||||
@patch('vllm.distributed.parallel_state._TP',
|
@patch('vllm.distributed.parallel_state._TP',
|
||||||
new_callable=lambda: MagicMock(spec=GroupCoordinator))
|
new_callable=lambda: MagicMock(spec=GroupCoordinator))
|
||||||
@patch("vllm.distributed.get_tensor_model_parallel_world_size",
|
|
||||||
return_value=2)
|
|
||||||
@patch("vllm_ascend.attention.mla_v1.get_current_vllm_config")
|
@patch("vllm_ascend.attention.mla_v1.get_current_vllm_config")
|
||||||
@patch("vllm_ascend.attention.mla_v1.get_ascend_config")
|
def setUp(self, get_current_vllm_config, mock_tp, mock_dcp, mock_pcp):
|
||||||
def setUp(self, ascend_config, get_current_vllm_config, mock_get_tp_size,
|
|
||||||
mock_tp, mock_get_dcp_size, mock_dcp, mock_pcp):
|
|
||||||
mock_tp.world_size = 2
|
mock_tp.world_size = 2
|
||||||
mock_tp.rank_in_group = MagicMock()
|
mock_tp.rank_in_group = MagicMock()
|
||||||
mock_tp.device_group = MagicMock()
|
mock_tp.device_group = MagicMock()
|
||||||
|
|||||||
40
vllm_ascend/attention/common_cp.py
Normal file
40
vllm_ascend/attention/common_cp.py
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AscendPCPMetadata:
|
||||||
|
q_head_idx: torch.Tensor = None
|
||||||
|
q_tail_idx: torch.Tensor = None
|
||||||
|
kv_with_q_head_nomask_idx: torch.Tensor = None
|
||||||
|
kv_with_q_head_mask_idx: torch.Tensor = None
|
||||||
|
kv_with_q_tail_nomask_idx: torch.Tensor = None
|
||||||
|
kv_with_q_tail_mask_idx: torch.Tensor = None
|
||||||
|
attn_mask_seqlens: torch.Tensor = None
|
||||||
|
head_attn_nomask_seqlens: torch.Tensor = None
|
||||||
|
tail_attn_nomask_seqlens: torch.Tensor = None
|
||||||
|
q_full_idx: torch.Tensor = None
|
||||||
|
pcp_prefill_mask: torch.Tensor = None
|
||||||
|
pcp_allgather_restore_idx: Optional[list[int]] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CPChunkedContextMetadata:
|
||||||
|
# New for MLA (compared to FlashAttention)
|
||||||
|
# For handling chunked prefill
|
||||||
|
cu_seq_lens: torch.Tensor
|
||||||
|
starts: torch.Tensor
|
||||||
|
seq_tot: list[int]
|
||||||
|
max_seq_lens: list[int]
|
||||||
|
workspace: torch.Tensor
|
||||||
|
chunk_seq_lens: torch.Tensor
|
||||||
|
chunk_seq_lens_npu: torch.Tensor
|
||||||
|
# for mla DCP & PCP
|
||||||
|
padded_chunk_seq_lens_npu: torch.Tensor = None
|
||||||
|
padded_local_chunk_seq_lens: Optional[list[list[int]]] = None
|
||||||
|
local_context_lens_allranks: Optional[list[list[int]]] = None
|
||||||
|
padded_local_cu_seq_lens: torch.Tensor = None
|
||||||
|
cu_seq_lens_lst: Optional[list[list[int]]] = None
|
||||||
|
chunk_size: Optional[int] = None
|
||||||
@@ -5,31 +5,32 @@ import torch
|
|||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
import torch_npu
|
import torch_npu
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from vllm.attention.backends.utils import PAD_SLOT_ID
|
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.distributed import (get_dcp_group,
|
from vllm.distributed import (get_dcp_group,
|
||||||
get_decode_context_model_parallel_rank,
|
get_decode_context_model_parallel_rank,
|
||||||
get_decode_context_model_parallel_world_size,
|
get_decode_context_model_parallel_world_size,
|
||||||
get_pcp_group)
|
get_pcp_group)
|
||||||
from vllm.forward_context import ForwardContext, get_forward_context
|
from vllm.forward_context import ForwardContext, get_forward_context
|
||||||
from vllm.utils.math_utils import cdiv, round_down
|
from vllm.utils.math_utils import cdiv
|
||||||
from vllm.v1.attention.backends.utils import AttentionCGSupport
|
from vllm.v1.attention.backends.utils import AttentionCGSupport
|
||||||
from vllm.v1.kv_cache_interface import MLAAttentionSpec
|
from vllm.v1.kv_cache_interface import MLAAttentionSpec
|
||||||
|
|
||||||
|
# isort: off
|
||||||
from vllm_ascend.attention.mla_v1 import (AscendMLADecodeMetadata,
|
from vllm_ascend.attention.mla_v1 import (AscendMLADecodeMetadata,
|
||||||
AscendMLAImpl, AscendMLAMetadata,
|
AscendMLAImpl, AscendMLAMetadata,
|
||||||
AscendMLAMetadataBuilder,
|
AscendMLAMetadataBuilder,
|
||||||
AscendMLAPrefillMetadata,
|
AscendMLAPrefillMetadata,
|
||||||
DecodeMLAPreprocessResult,
|
DecodeMLAPreprocessResult,
|
||||||
PrefillMLAPreprocessResult)
|
PrefillMLAPreprocessResult)
|
||||||
|
#isort: on
|
||||||
|
|
||||||
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
||||||
maybe_save_kv_layer_to_connector,
|
maybe_save_kv_layer_to_connector,
|
||||||
split_decodes_and_prefills,
|
|
||||||
wait_for_kv_layer_from_connector)
|
wait_for_kv_layer_from_connector)
|
||||||
|
from vllm_ascend.attention.common_cp import AscendPCPMetadata, CPChunkedContextMetadata
|
||||||
from vllm_ascend.compilation.acl_graph import (get_graph_params,
|
from vllm_ascend.compilation.acl_graph import (get_graph_params,
|
||||||
get_mtp_graph_params,
|
get_mtp_graph_params,
|
||||||
update_graph_params_workspaces)
|
update_graph_params_workspaces)
|
||||||
from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla
|
|
||||||
from vllm_ascend.ops.shared_weight_layer import (
|
from vllm_ascend.ops.shared_weight_layer import (
|
||||||
is_hidden_layer, reach_layer_for_shared_weight_series)
|
is_hidden_layer, reach_layer_for_shared_weight_series)
|
||||||
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
|
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
|
||||||
@@ -75,162 +76,87 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
|
|||||||
dtype=torch.uint8,
|
dtype=torch.uint8,
|
||||||
device=device)
|
device=device)
|
||||||
|
|
||||||
def build(
|
def set_num_actual_tokens(
|
||||||
|
self,
|
||||||
|
common_attn_metadata: AscendCommonAttentionMetadata,
|
||||||
|
):
|
||||||
|
long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
|
||||||
|
if long_seq_metadata is None:
|
||||||
|
raise AssertionError("long_seq_metadata should not be None.")
|
||||||
|
|
||||||
|
self.num_actual_tokens = max(
|
||||||
|
long_seq_metadata.num_actual_tokens_pcp_padded,
|
||||||
|
common_attn_metadata.num_actual_tokens)
|
||||||
|
|
||||||
|
def build_cp_metadata(
|
||||||
self,
|
self,
|
||||||
common_prefix_len: int,
|
common_prefix_len: int,
|
||||||
common_attn_metadata: AscendCommonAttentionMetadata,
|
common_attn_metadata: AscendCommonAttentionMetadata,
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
) -> AscendMLAMetadata:
|
) -> AscendPCPMetadata | None:
|
||||||
num_reqs = common_attn_metadata.num_reqs
|
common_long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
|
||||||
num_actual_tokens = common_attn_metadata.num_actual_tokens
|
assert common_long_seq_metadata is not None
|
||||||
query_start_loc = common_attn_metadata.query_start_loc
|
return AscendPCPMetadata(
|
||||||
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
|
q_head_idx=common_long_seq_metadata.q_head_idx_tensor,
|
||||||
long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
|
q_tail_idx=common_long_seq_metadata.q_tail_idx_tensor,
|
||||||
|
kv_with_q_head_nomask_idx=common_long_seq_metadata.
|
||||||
if long_seq_metadata is None:
|
|
||||||
raise AssertionError("long_seq_metadata should not be None.")
|
|
||||||
|
|
||||||
num_actual_tokens_pcp_padded = long_seq_metadata.num_actual_tokens_pcp_padded
|
|
||||||
if num_actual_tokens_pcp_padded is None:
|
|
||||||
num_actual_tokens_pcp_padded = num_actual_tokens
|
|
||||||
# In dcp only spec decode graph padding case,
|
|
||||||
# num_actual_tokens_pcp_padded may be less than num_actual_tokens
|
|
||||||
num_actual_tokens_pcp_padded = max(num_actual_tokens_pcp_padded,
|
|
||||||
num_actual_tokens)
|
|
||||||
num_computed_tokens_of_pcp_dcp = long_seq_metadata.num_computed_tokens_of_pcp_dcp
|
|
||||||
assert num_computed_tokens_of_pcp_dcp is not None
|
|
||||||
|
|
||||||
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
|
|
||||||
split_decodes_and_prefills(common_attn_metadata, decode_threshold=self.decode_threshold)
|
|
||||||
assert num_decodes + num_prefills == num_reqs
|
|
||||||
assert num_decode_tokens + num_prefill_tokens == num_actual_tokens
|
|
||||||
|
|
||||||
# Note(simon): be careful about the CPU <> GPU memory movement in this
|
|
||||||
# function. We should avoid GPU -> CPU sync as much as possible because
|
|
||||||
# it blocks on all previous kernels.
|
|
||||||
device = self.device
|
|
||||||
|
|
||||||
# If graph_pad_size > -1, mean is running in fullgraph mode.
|
|
||||||
graph_pad_size = common_attn_metadata.graph_pad_size
|
|
||||||
# NOTE: Maybe this block_table change can be removed when graph_pad_size > 1.
|
|
||||||
if graph_pad_size > num_reqs and self.speculative_config.disable_padded_drafter_batch:
|
|
||||||
block_table = (
|
|
||||||
common_attn_metadata.block_table_tensor[:graph_pad_size])
|
|
||||||
else:
|
|
||||||
block_table = (common_attn_metadata.block_table_tensor[:num_reqs])
|
|
||||||
slot_mapping = common_attn_metadata.slot_mapping[:
|
|
||||||
num_actual_tokens_pcp_padded]
|
|
||||||
input_positions = common_attn_metadata.positions[:
|
|
||||||
num_actual_tokens_pcp_padded].long(
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.cos_cache is None:
|
|
||||||
self.cos_cache = model.model.layers[
|
|
||||||
model.model.start_layer].self_attn.rotary_emb.cos_cached
|
|
||||||
self.sin_cache = model.model.layers[
|
|
||||||
model.model.start_layer].self_attn.rotary_emb.sin_cached
|
|
||||||
if self.cos_cache.dtype != self.model_config.dtype: # type: ignore
|
|
||||||
self.cos_cache = self.cos_cache.to( # type: ignore
|
|
||||||
self.model_config.dtype) # type: ignore
|
|
||||||
self.sin_cache = self.sin_cache.to( # type: ignore
|
|
||||||
self.model_config.dtype) # type: ignore
|
|
||||||
|
|
||||||
query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
|
|
||||||
query_lens = query_seq_lens_cpu[:num_reqs]
|
|
||||||
seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
|
|
||||||
num_computed_tokens_cpu = (seq_lens - query_lens)
|
|
||||||
|
|
||||||
# For pcp + spec decode, we flatten seq_lens and block_table
|
|
||||||
# to avoid irregular spec_attn_mask shape
|
|
||||||
num_decodes_flatten = query_lens[:num_decodes].sum().item()
|
|
||||||
block_table = common_attn_metadata.block_table_tensor[:
|
|
||||||
num_decodes_flatten
|
|
||||||
+ num_prefills]
|
|
||||||
|
|
||||||
prefill_metadata = None
|
|
||||||
chunked_context_metadata = None
|
|
||||||
if num_prefills > 0:
|
|
||||||
pcp_metadata = AscendMLAPrefillMetadata.AscendPCPMetadata(
|
|
||||||
q_head_idx=long_seq_metadata.q_head_idx_tensor,
|
|
||||||
q_tail_idx=long_seq_metadata.q_tail_idx_tensor,
|
|
||||||
kv_with_q_head_nomask_idx=long_seq_metadata.
|
|
||||||
kv_with_q_head_nomask_idx_tensor,
|
kv_with_q_head_nomask_idx_tensor,
|
||||||
kv_with_q_head_mask_idx=long_seq_metadata.
|
kv_with_q_head_mask_idx=common_long_seq_metadata.
|
||||||
kv_with_q_head_mask_idx_tensor,
|
kv_with_q_head_mask_idx_tensor,
|
||||||
kv_with_q_tail_nomask_idx=long_seq_metadata.
|
kv_with_q_tail_nomask_idx=common_long_seq_metadata.
|
||||||
kv_with_q_tail_nomask_idx_tensor,
|
kv_with_q_tail_nomask_idx_tensor,
|
||||||
kv_with_q_tail_mask_idx=long_seq_metadata.
|
kv_with_q_tail_mask_idx=common_long_seq_metadata.
|
||||||
kv_with_q_tail_mask_idx_tensor,
|
kv_with_q_tail_mask_idx_tensor,
|
||||||
attn_mask_seqlens=long_seq_metadata.attn_mask_seqlens,
|
attn_mask_seqlens=common_long_seq_metadata.attn_mask_seqlens,
|
||||||
head_attn_nomask_seqlens=long_seq_metadata.
|
head_attn_nomask_seqlens=common_long_seq_metadata.
|
||||||
head_attn_nomask_seqlens,
|
head_attn_nomask_seqlens,
|
||||||
tail_attn_nomask_seqlens=long_seq_metadata.
|
tail_attn_nomask_seqlens=common_long_seq_metadata.
|
||||||
tail_attn_nomask_seqlens,
|
tail_attn_nomask_seqlens,
|
||||||
q_full_idx=long_seq_metadata.q_full_idx,
|
q_full_idx=common_long_seq_metadata.q_full_idx,
|
||||||
pcp_prefill_mask=long_seq_metadata.pcp_prefill_mask,
|
pcp_prefill_mask=common_long_seq_metadata.pcp_prefill_mask,
|
||||||
pcp_allgather_restore_idx=long_seq_metadata.
|
pcp_allgather_restore_idx=common_long_seq_metadata.
|
||||||
pcp_allgather_restore_idx)
|
pcp_allgather_restore_idx)
|
||||||
|
|
||||||
reqs_start = num_decodes # prefill_start
|
def build_chunked_metadata(
|
||||||
tokens_start = num_decode_tokens
|
self,
|
||||||
max_query_len = query_lens[reqs_start:].max().item()
|
common_prefix_len: int,
|
||||||
max_seq_lens = seq_lens[reqs_start:].max().item()
|
common_attn_metadata: AscendCommonAttentionMetadata,
|
||||||
prefill_query_start_loc = query_start_loc[
|
model: nn.Module,
|
||||||
reqs_start:] - query_start_loc[reqs_start]
|
):
|
||||||
|
chunked_context_metadata = super().build_chunked_metadata(
|
||||||
context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs]
|
common_prefix_len, common_attn_metadata, model)
|
||||||
max_context_len_cpu = context_lens_cpu.max().item()
|
if chunked_context_metadata is None:
|
||||||
num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item()
|
return None
|
||||||
if self.chunked_prefill_enabled and max_context_len_cpu > 0:
|
|
||||||
max_context_chunk = (self.chunked_prefill_workspace_size //
|
|
||||||
num_prefills_with_context_cpu)
|
|
||||||
max_context_chunk = round_down(max_context_chunk,
|
|
||||||
self.block_size)
|
|
||||||
|
|
||||||
assert max_context_chunk > 0
|
|
||||||
num_chunks = cdiv(max_context_len_cpu, max_context_chunk)
|
|
||||||
chunk_starts = torch.arange(num_chunks, dtype=torch.int32) \
|
|
||||||
.unsqueeze(1).expand(-1, num_prefills) * max_context_chunk
|
|
||||||
chunk_ends = torch.min(context_lens_cpu.unsqueeze(0),
|
|
||||||
chunk_starts + max_context_chunk)
|
|
||||||
chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)
|
|
||||||
cu_seq_lens_cpu = torch.zeros(num_chunks,
|
|
||||||
num_prefills + 1,
|
|
||||||
dtype=torch.int32,
|
|
||||||
pin_memory=True)
|
|
||||||
torch.cumsum(chunk_seq_lens,
|
|
||||||
dim=1,
|
|
||||||
out=cu_seq_lens_cpu[:, 1:],
|
|
||||||
dtype=torch.int32)
|
|
||||||
|
|
||||||
|
long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
|
||||||
|
assert long_seq_metadata is not None
|
||||||
|
num_computed_tokens_of_pcp_dcp = long_seq_metadata.num_computed_tokens_of_pcp_dcp
|
||||||
|
assert num_computed_tokens_of_pcp_dcp is not None
|
||||||
local_context_lens_allranks = torch.tensor(
|
local_context_lens_allranks = torch.tensor(
|
||||||
num_computed_tokens_of_pcp_dcp[num_decodes_flatten:]
|
num_computed_tokens_of_pcp_dcp[self.num_decodes_flatten:]).reshape(
|
||||||
).reshape(-1, self.dcp_size * self.pcp_size)
|
-1, self.dcp_size * self.pcp_size)
|
||||||
# Note(qcs): The max local context lengths
|
# Note(qcs): The max local context lengths
|
||||||
# padded to `cp_local_block_size`.
|
# padded to `cp_local_block_size`.
|
||||||
padded_local_context_lens_cpu = (cdiv(
|
padded_local_context_lens_cpu = (cdiv(
|
||||||
context_lens_cpu,
|
self.context_lens_cpu,
|
||||||
self.cp_virtual_block_size,
|
self.cp_virtual_block_size,
|
||||||
) * self.cp_local_block_size)
|
) * self.cp_local_block_size)
|
||||||
padded_local_max_context_chunk_across_ranks = (cdiv(
|
padded_local_max_context_chunk_across_ranks = (cdiv(
|
||||||
max_context_chunk,
|
self.max_context_chunk,
|
||||||
self.cp_virtual_block_size,
|
self.cp_virtual_block_size,
|
||||||
) * self.cp_local_block_size)
|
) * self.cp_local_block_size)
|
||||||
local_chunk_starts = (
|
local_chunk_starts = (torch.arange(
|
||||||
torch.arange(num_chunks,
|
self.num_chunks, dtype=torch.int32).unsqueeze(1).expand(
|
||||||
dtype=torch.int32).unsqueeze(1).expand(
|
-1, self.num_prefills) *
|
||||||
-1, num_prefills) *
|
|
||||||
padded_local_max_context_chunk_across_ranks)
|
padded_local_max_context_chunk_across_ranks)
|
||||||
local_chunk_ends = torch.min(
|
local_chunk_ends = torch.min(
|
||||||
padded_local_context_lens_cpu.unsqueeze(0),
|
padded_local_context_lens_cpu.unsqueeze(0),
|
||||||
local_chunk_starts +
|
local_chunk_starts + padded_local_max_context_chunk_across_ranks,
|
||||||
padded_local_max_context_chunk_across_ranks,
|
|
||||||
)
|
)
|
||||||
padded_local_chunk_seq_lens = (local_chunk_ends -
|
padded_local_chunk_seq_lens = (local_chunk_ends -
|
||||||
local_chunk_starts).clamp(min=0)
|
local_chunk_starts).clamp(min=0)
|
||||||
padded_local_cu_chunk_seq_lens_cpu = torch.zeros(
|
padded_local_cu_chunk_seq_lens_cpu = torch.zeros(self.num_chunks,
|
||||||
num_chunks,
|
self.num_prefills + 1,
|
||||||
num_prefills + 1,
|
|
||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
pin_memory=True)
|
pin_memory=True)
|
||||||
torch.cumsum(
|
torch.cumsum(
|
||||||
@@ -239,72 +165,72 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
|
|||||||
out=padded_local_cu_chunk_seq_lens_cpu[:, 1:],
|
out=padded_local_cu_chunk_seq_lens_cpu[:, 1:],
|
||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
)
|
)
|
||||||
chunked_context_metadata = AscendMLAPrefillMetadata.ChunkedContextMetadata(
|
chunked_metadata = CPChunkedContextMetadata(
|
||||||
cu_seq_lens=cu_seq_lens_cpu.pin_memory().to(
|
cu_seq_lens=chunked_context_metadata.cu_seq_lens,
|
||||||
device, non_blocking=True),
|
starts=local_chunk_starts.pin_memory().to(self.device,
|
||||||
starts=local_chunk_starts.pin_memory().to(
|
non_blocking=True),
|
||||||
device, non_blocking=True),
|
|
||||||
seq_tot=padded_local_chunk_seq_lens.sum(dim=1).tolist(),
|
seq_tot=padded_local_chunk_seq_lens.sum(dim=1).tolist(),
|
||||||
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
|
max_seq_lens=chunked_context_metadata.max_seq_lens,
|
||||||
chunk_seq_lens=chunk_seq_lens,
|
chunk_seq_lens=self.chunk_seq_lens,
|
||||||
chunk_seq_lens_npu=chunk_seq_lens.npu(),
|
chunk_seq_lens_npu=chunked_context_metadata.chunk_seq_lens_npu,
|
||||||
workspace=self.chunked_prefill_workspace,
|
workspace=chunked_context_metadata.workspace,
|
||||||
padded_chunk_seq_lens_npu=padded_local_chunk_seq_lens.npu(
|
padded_chunk_seq_lens_npu=padded_local_chunk_seq_lens.npu(),
|
||||||
),
|
padded_local_chunk_seq_lens=padded_local_chunk_seq_lens.tolist(),
|
||||||
padded_local_chunk_seq_lens=padded_local_chunk_seq_lens.
|
local_context_lens_allranks=local_context_lens_allranks.tolist(),
|
||||||
tolist(),
|
padded_local_cu_seq_lens=padded_local_cu_chunk_seq_lens_cpu.
|
||||||
local_context_lens_allranks=local_context_lens_allranks.
|
pin_memory().to(self.device, non_blocking=True),
|
||||||
tolist(),
|
cu_seq_lens_lst=self.cu_seq_lens_cpu.tolist(),
|
||||||
padded_local_cu_seq_lens=padded_local_cu_chunk_seq_lens_cpu
|
|
||||||
.pin_memory().to(device, non_blocking=True),
|
|
||||||
cu_seq_lens_lst=cu_seq_lens_cpu.tolist(),
|
|
||||||
chunk_size=padded_local_max_context_chunk_across_ranks,
|
chunk_size=padded_local_max_context_chunk_across_ranks,
|
||||||
)
|
)
|
||||||
|
return chunked_metadata
|
||||||
|
|
||||||
prefill_input_positions = input_positions[tokens_start:]
|
def set_prefill_block_table(
|
||||||
assert self.cos_cache is not None
|
self,
|
||||||
assert self.sin_cache is not None
|
common_attn_metadata: AscendCommonAttentionMetadata,
|
||||||
cos = self.cos_cache[prefill_input_positions].unsqueeze(
|
):
|
||||||
1).unsqueeze(2)
|
# For pcp + spec decode, we flatten seq_lens and block_table
|
||||||
sin = self.sin_cache[prefill_input_positions].unsqueeze(
|
# to avoid irregular spec_attn_mask shape
|
||||||
1).unsqueeze(2)
|
self.num_decodes_flatten = self.query_lens[:self.num_decodes].sum(
|
||||||
prefill_metadata = AscendMLAPrefillMetadata(
|
).item()
|
||||||
attn_mask=common_attn_metadata.attn_mask,
|
self.block_table = common_attn_metadata.block_table_tensor[:self.
|
||||||
query_lens=query_lens[reqs_start:].to(torch.int32),
|
num_decodes_flatten
|
||||||
seq_lens=seq_lens,
|
+ self.
|
||||||
context_lens=seq_lens[reqs_start:],
|
num_prefills]
|
||||||
input_positions=prefill_input_positions,
|
|
||||||
block_table=block_table[reqs_start:, ...],
|
|
||||||
max_query_len=max_query_len,
|
|
||||||
max_seq_lens=max_seq_lens,
|
|
||||||
query_start_loc=prefill_query_start_loc,
|
|
||||||
chunked_context=chunked_context_metadata,
|
|
||||||
sin=sin,
|
|
||||||
cos=cos,
|
|
||||||
pcp_metadata=pcp_metadata,
|
|
||||||
)
|
|
||||||
prefill_metadata.block_table = \
|
|
||||||
block_table[num_decodes_flatten:, ...]
|
|
||||||
|
|
||||||
decode_metadata = None
|
def set_decode_block_table(
|
||||||
if num_decodes > 0:
|
self, common_attn_metadata: AscendCommonAttentionMetadata):
|
||||||
cos, sin = get_cos_and_sin_mla()
|
self.block_table = self.block_table[:self.num_decodes_flatten, ...]
|
||||||
# Notice that num_decodes != num_decode_tokens in SpecDecoding Scenario
|
|
||||||
actual_seq_lengths_q = query_start_loc_cpu[1:num_decodes +
|
|
||||||
1].tolist()
|
|
||||||
max_seq_lens = seq_lens[:num_decodes].max().item()
|
|
||||||
seq_lens = seq_lens[:num_decodes]
|
|
||||||
input_positions = input_positions[:num_decode_tokens]
|
|
||||||
block_table = block_table[:num_decodes_flatten, ...]
|
|
||||||
# NOTE: Maybe this block_table change can be removed when graph_pad_size > 1.
|
|
||||||
if graph_pad_size > num_decodes and \
|
|
||||||
self.speculative_config.disable_padded_drafter_batch:
|
|
||||||
block_table = block_table[:graph_pad_size, ...]
|
|
||||||
seq_lens_list = seq_lens.tolist()
|
|
||||||
|
|
||||||
|
def build_prefill_metadata(
|
||||||
|
self,
|
||||||
|
common_prefix_len: int,
|
||||||
|
common_attn_metadata: AscendCommonAttentionMetadata,
|
||||||
|
model: nn.Module,
|
||||||
|
) -> AscendMLAPrefillMetadata:
|
||||||
|
prefill_metadata = super().build_prefill_metadata(
|
||||||
|
common_prefix_len, common_attn_metadata, model)
|
||||||
|
prefill_metadata.pcp_metadata = self.build_cp_metadata(
|
||||||
|
common_prefix_len, common_attn_metadata, model)
|
||||||
|
prefill_metadata.block_table = self.block_table[
|
||||||
|
self.num_decodes_flatten:, ...]
|
||||||
|
return prefill_metadata
|
||||||
|
|
||||||
|
def build_decode_metadata(
|
||||||
|
self,
|
||||||
|
common_prefix_len: int,
|
||||||
|
common_attn_metadata: AscendCommonAttentionMetadata,
|
||||||
|
model: nn.Module,
|
||||||
|
) -> AscendMLADecodeMetadata:
|
||||||
|
decode_metadata = super().build_decode_metadata(
|
||||||
|
common_prefix_len, common_attn_metadata, model)
|
||||||
|
|
||||||
|
long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
|
||||||
|
assert long_seq_metadata is not None
|
||||||
|
num_computed_tokens_of_pcp_dcp = long_seq_metadata.num_computed_tokens_of_pcp_dcp
|
||||||
|
assert num_computed_tokens_of_pcp_dcp is not None
|
||||||
# [bs, pcp_size, dcp_size]
|
# [bs, pcp_size, dcp_size]
|
||||||
num_computed_tokens_of_cp_dcp_array = np.array(
|
num_computed_tokens_of_cp_dcp_array = np.array(
|
||||||
num_computed_tokens_of_pcp_dcp)[:num_decodes_flatten]
|
num_computed_tokens_of_pcp_dcp)[:self.num_decodes_flatten]
|
||||||
|
|
||||||
cp_seq_len = num_computed_tokens_of_cp_dcp_array[:, self.pcp_rank,
|
cp_seq_len = num_computed_tokens_of_cp_dcp_array[:, self.pcp_rank,
|
||||||
self.dcp_rank]
|
self.dcp_rank]
|
||||||
@@ -314,115 +240,9 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
|
|||||||
batch_seq_mask, non_blocking=True)
|
batch_seq_mask, non_blocking=True)
|
||||||
batch_seq_mask = self.batch_seq_mask_buf[:batch_seq_mask.shape[0]]
|
batch_seq_mask = self.batch_seq_mask_buf[:batch_seq_mask.shape[0]]
|
||||||
cp_seq_len = torch.where(cp_seq_len == 0, 1, cp_seq_len)
|
cp_seq_len = torch.where(cp_seq_len == 0, 1, cp_seq_len)
|
||||||
|
decode_metadata.cp_seq_len = cp_seq_len
|
||||||
if graph_pad_size > num_reqs:
|
decode_metadata.batch_seq_mask = batch_seq_mask
|
||||||
if self.speculative_config.disable_padded_drafter_batch:
|
return decode_metadata
|
||||||
num_reqs_pad_size = graph_pad_size - num_reqs
|
|
||||||
actual_seq_lengths_q = self.pad_actual_seq_len_q_mtp_disable_pad(
|
|
||||||
num_reqs_pad_size, num_reqs, actual_seq_lengths_q)
|
|
||||||
seq_lens_list = seq_lens_list + [0] * (graph_pad_size -
|
|
||||||
num_decodes)
|
|
||||||
num_block_pad_size = graph_pad_size - block_table.shape[0]
|
|
||||||
if num_block_pad_size > 0:
|
|
||||||
block_table_padding = torch.zeros(
|
|
||||||
(num_block_pad_size, ) + block_table.shape[1:],
|
|
||||||
dtype=block_table.dtype,
|
|
||||||
device=block_table.device)
|
|
||||||
block_table = torch.cat(
|
|
||||||
[block_table, block_table_padding], dim=0)
|
|
||||||
else:
|
|
||||||
num_token_pad_size = graph_pad_size - num_decode_tokens
|
|
||||||
num_reqs_pad_size = (
|
|
||||||
graph_pad_size //
|
|
||||||
common_attn_metadata.decode_token_per_req - num_reqs)
|
|
||||||
num_block_table_pad_size = (
|
|
||||||
graph_pad_size //
|
|
||||||
common_attn_metadata.decode_token_per_req -
|
|
||||||
num_decodes)
|
|
||||||
seq_lens_list = seq_lens.tolist() + [0] * num_reqs_pad_size
|
|
||||||
slot_padding = torch.full((num_token_pad_size, ),
|
|
||||||
PAD_SLOT_ID,
|
|
||||||
dtype=slot_mapping.dtype,
|
|
||||||
device=slot_mapping.device)
|
|
||||||
slot_mapping = torch.cat([slot_mapping, slot_padding])
|
|
||||||
block_table_padding = torch.zeros(
|
|
||||||
(num_block_table_pad_size, ) + block_table.shape[1:],
|
|
||||||
dtype=block_table.dtype,
|
|
||||||
device=block_table.device)
|
|
||||||
block_table = torch.cat([block_table, block_table_padding],
|
|
||||||
dim=0)
|
|
||||||
position_padding = torch.zeros(
|
|
||||||
num_token_pad_size,
|
|
||||||
dtype=input_positions.dtype,
|
|
||||||
device=input_positions.device)
|
|
||||||
input_positions = torch.cat(
|
|
||||||
[input_positions, position_padding])
|
|
||||||
actual_seq_lengths_q = self.pad_actual_seq_len_q_mtp_enable_pad(
|
|
||||||
num_reqs_pad_size, num_reqs, actual_seq_lengths_q,
|
|
||||||
common_attn_metadata)
|
|
||||||
|
|
||||||
# TODO: After the fullgraph supports MTP, the if branch needs to deleted
|
|
||||||
assert self.cos_cache is not None
|
|
||||||
assert self.sin_cache is not None
|
|
||||||
if cos is None and sin is None:
|
|
||||||
cos = self.cos_cache[
|
|
||||||
input_positions].unsqueeze( # type: ignore
|
|
||||||
1).unsqueeze(2)
|
|
||||||
sin = self.sin_cache[
|
|
||||||
input_positions].unsqueeze( # type: ignore
|
|
||||||
1).unsqueeze(2)
|
|
||||||
|
|
||||||
decode_metadata = AscendMLADecodeMetadata(
|
|
||||||
input_positions=input_positions,
|
|
||||||
block_table=block_table,
|
|
||||||
seq_lens=seq_lens,
|
|
||||||
seq_lens_list=seq_lens_list,
|
|
||||||
max_seq_lens=max_seq_lens,
|
|
||||||
attn_mask=common_attn_metadata.spec_attn_mask,
|
|
||||||
actual_seq_lengths_q=actual_seq_lengths_q,
|
|
||||||
sin=sin,
|
|
||||||
cos=cos,
|
|
||||||
cp_seq_len=cp_seq_len,
|
|
||||||
batch_seq_mask=batch_seq_mask)
|
|
||||||
else:
|
|
||||||
cos[:num_decode_tokens,
|
|
||||||
...] = self.cos_cache[input_positions].unsqueeze(
|
|
||||||
1).unsqueeze(2)
|
|
||||||
sin[:num_decode_tokens,
|
|
||||||
...] = self.sin_cache[input_positions].unsqueeze(
|
|
||||||
1).unsqueeze(2)
|
|
||||||
|
|
||||||
decode_metadata = AscendMLADecodeMetadata(
|
|
||||||
input_positions=input_positions,
|
|
||||||
block_table=block_table,
|
|
||||||
seq_lens=seq_lens,
|
|
||||||
seq_lens_list=seq_lens_list,
|
|
||||||
max_seq_lens=max_seq_lens,
|
|
||||||
attn_mask=common_attn_metadata.spec_attn_mask,
|
|
||||||
actual_seq_lengths_q=actual_seq_lengths_q,
|
|
||||||
sin=sin[:num_decode_tokens, ...],
|
|
||||||
cos=cos[:num_decode_tokens, ...],
|
|
||||||
cp_seq_len=cp_seq_len,
|
|
||||||
batch_seq_mask=batch_seq_mask)
|
|
||||||
|
|
||||||
return self.metadata_cls( # type: ignore
|
|
||||||
num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded,
|
|
||||||
num_input_tokens=common_attn_metadata.num_input_tokens,
|
|
||||||
num_actual_tokens=num_actual_tokens,
|
|
||||||
query_lens=query_lens.tolist(),
|
|
||||||
slot_mapping=slot_mapping,
|
|
||||||
head_dim=self.model_config.get_head_size(),
|
|
||||||
num_decodes=num_decodes,
|
|
||||||
num_decode_tokens=num_decode_tokens,
|
|
||||||
num_prefills=num_prefills,
|
|
||||||
attn_mask=common_attn_metadata.attn_mask,
|
|
||||||
attn_state=common_attn_metadata.attn_state,
|
|
||||||
prefill=prefill_metadata,
|
|
||||||
decode=decode_metadata,
|
|
||||||
query_start_loc=query_start_loc,
|
|
||||||
block_tables=block_table,
|
|
||||||
seq_lens=seq_lens,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class AscendMlaCPImpl(AscendMLAImpl):
|
class AscendMlaCPImpl(AscendMLAImpl):
|
||||||
|
|||||||
@@ -9,9 +9,6 @@ from torch import nn
|
|||||||
from vllm.attention.backends.abstract import AttentionBackend, MLAAttentionImpl
|
from vllm.attention.backends.abstract import AttentionBackend, MLAAttentionImpl
|
||||||
from vllm.attention.backends.utils import PAD_SLOT_ID
|
from vllm.attention.backends.utils import PAD_SLOT_ID
|
||||||
from vllm.config import VllmConfig, get_current_vllm_config
|
from vllm.config import VllmConfig, get_current_vllm_config
|
||||||
from vllm.distributed import (get_decode_context_model_parallel_rank,
|
|
||||||
get_decode_context_model_parallel_world_size,
|
|
||||||
get_pcp_group)
|
|
||||||
from vllm.forward_context import ForwardContext, get_forward_context
|
from vllm.forward_context import ForwardContext, get_forward_context
|
||||||
from vllm.logger import logger
|
from vllm.logger import logger
|
||||||
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
|
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
|
||||||
@@ -22,6 +19,8 @@ from vllm.v1.kv_cache_interface import MLAAttentionSpec
|
|||||||
from vllm_ascend import envs
|
from vllm_ascend import envs
|
||||||
from vllm_ascend.ascend_config import get_ascend_config
|
from vllm_ascend.ascend_config import get_ascend_config
|
||||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||||
|
from vllm_ascend.attention.common_cp import (AscendPCPMetadata,
|
||||||
|
CPChunkedContextMetadata)
|
||||||
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
||||||
enable_cp,
|
enable_cp,
|
||||||
maybe_save_kv_layer_to_connector,
|
maybe_save_kv_layer_to_connector,
|
||||||
@@ -77,11 +76,7 @@ class AscendMLABackend(AttentionBackend):
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AscendMLAPrefillMetadata:
|
class ChunkedContextMetadata:
|
||||||
""" Prefill Specific Metadata for Ascend"""
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ChunkedContextMetadata:
|
|
||||||
# New for MLA (compared to FlashAttention)
|
# New for MLA (compared to FlashAttention)
|
||||||
# For handling chunked prefill
|
# For handling chunked prefill
|
||||||
cu_seq_lens: torch.Tensor
|
cu_seq_lens: torch.Tensor
|
||||||
@@ -91,29 +86,11 @@ class AscendMLAPrefillMetadata:
|
|||||||
workspace: torch.Tensor
|
workspace: torch.Tensor
|
||||||
chunk_seq_lens: torch.Tensor
|
chunk_seq_lens: torch.Tensor
|
||||||
chunk_seq_lens_npu: torch.Tensor
|
chunk_seq_lens_npu: torch.Tensor
|
||||||
# for mla DCP & PCP
|
|
||||||
padded_chunk_seq_lens_npu: torch.Tensor = None
|
|
||||||
padded_local_chunk_seq_lens: Optional[list[list[int]]] = None
|
|
||||||
local_context_lens_allranks: Optional[list[list[int]]] = None
|
|
||||||
padded_local_cu_seq_lens: torch.Tensor = None
|
|
||||||
cu_seq_lens_lst: Optional[list[list[int]]] = None
|
|
||||||
chunk_size: Optional[int] = None
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class AscendPCPMetadata:
|
|
||||||
q_head_idx: torch.Tensor = None
|
|
||||||
q_tail_idx: torch.Tensor = None
|
|
||||||
kv_with_q_head_nomask_idx: torch.Tensor = None
|
|
||||||
kv_with_q_head_mask_idx: torch.Tensor = None
|
|
||||||
kv_with_q_tail_nomask_idx: torch.Tensor = None
|
|
||||||
kv_with_q_tail_mask_idx: torch.Tensor = None
|
|
||||||
attn_mask_seqlens: torch.Tensor = None
|
|
||||||
head_attn_nomask_seqlens: torch.Tensor = None
|
|
||||||
tail_attn_nomask_seqlens: torch.Tensor = None
|
|
||||||
q_full_idx: torch.Tensor = None
|
|
||||||
pcp_prefill_mask: torch.Tensor = None
|
|
||||||
pcp_allgather_restore_idx: Optional[list[int]] = None
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AscendMLAPrefillMetadata:
|
||||||
|
""" Prefill Specific Metadata for Ascend"""
|
||||||
attn_mask: torch.Tensor
|
attn_mask: torch.Tensor
|
||||||
query_lens: torch.Tensor
|
query_lens: torch.Tensor
|
||||||
seq_lens: list[int]
|
seq_lens: list[int]
|
||||||
@@ -123,7 +100,8 @@ class AscendMLAPrefillMetadata:
|
|||||||
block_table: torch.Tensor
|
block_table: torch.Tensor
|
||||||
max_query_len: int
|
max_query_len: int
|
||||||
max_seq_lens: int
|
max_seq_lens: int
|
||||||
chunked_context: Optional[ChunkedContextMetadata] = None
|
chunked_context: Optional[ChunkedContextMetadata
|
||||||
|
| CPChunkedContextMetadata] = None
|
||||||
sin: torch.Tensor = None
|
sin: torch.Tensor = None
|
||||||
cos: torch.Tensor = None
|
cos: torch.Tensor = None
|
||||||
pcp_metadata: Optional[AscendPCPMetadata] = None
|
pcp_metadata: Optional[AscendPCPMetadata] = None
|
||||||
@@ -262,21 +240,21 @@ class AscendMLAMetadataBuilder:
|
|||||||
self.cos_cache = None
|
self.cos_cache = None
|
||||||
self.sin_cache = None
|
self.sin_cache = None
|
||||||
|
|
||||||
self.pcp_size = get_pcp_group().world_size
|
self.chunk_seq_lens: torch.Tensor = None
|
||||||
self.pcp_rank = get_pcp_group(
|
self.cu_seq_lens_cpu: torch.Tensor = None
|
||||||
).rank_in_group if self.pcp_size > 1 else 0
|
self.num_chunks: torch.Tensor = None
|
||||||
self.dcp_size = get_decode_context_model_parallel_world_size()
|
self.max_context_chunk = 0
|
||||||
self.dcp_rank = get_decode_context_model_parallel_rank(
|
self.num_decodes = 0
|
||||||
) if self.dcp_size > 1 else 0
|
self.num_prefills = 0
|
||||||
self.cp_local_block_size = vllm_config.parallel_config.cp_kv_cache_interleave_size
|
self.num_decode_tokens = 0
|
||||||
self.cp_virtual_block_size = self.cp_local_block_size * self.dcp_size * self.pcp_size
|
self.num_prefill_tokens = 0
|
||||||
decode_max_num_seqs = getattr(scheduler_config, 'decode_max_num_seqs',
|
self.context_lens_cpu: torch.Tensor = None
|
||||||
0)
|
self.num_actual_tokens: Optional[int] = None
|
||||||
max_num_seqs = max(scheduler_config.max_num_seqs, decode_max_num_seqs)
|
self.block_table: torch.Tensor = None
|
||||||
self.batch_seq_mask_buf = torch.empty(max_num_seqs *
|
self.slot_mapping: torch.Tensor = None
|
||||||
self.decode_threshold,
|
self.graph_pad_size = 0
|
||||||
dtype=torch.uint8,
|
self.query_lens: torch.Tensor = None
|
||||||
device=device)
|
self.seq_lens: torch.Tensor = None
|
||||||
|
|
||||||
def reorder_batch(self, input_batch: "NPUInputBatch",
|
def reorder_batch(self, input_batch: "NPUInputBatch",
|
||||||
scheduler_output: "SchedulerOutput") -> bool:
|
scheduler_output: "SchedulerOutput") -> bool:
|
||||||
@@ -396,6 +374,12 @@ class AscendMLAMetadataBuilder:
|
|||||||
actual_seq_lengths_q = actual_seq_lengths_q + interpolated
|
actual_seq_lengths_q = actual_seq_lengths_q + interpolated
|
||||||
return actual_seq_lengths_q
|
return actual_seq_lengths_q
|
||||||
|
|
||||||
|
def set_num_actual_tokens(
|
||||||
|
self,
|
||||||
|
common_attn_metadata: AscendCommonAttentionMetadata,
|
||||||
|
):
|
||||||
|
self.num_actual_tokens = common_attn_metadata.num_actual_tokens
|
||||||
|
|
||||||
def build(
|
def build(
|
||||||
self,
|
self,
|
||||||
common_prefix_len: int,
|
common_prefix_len: int,
|
||||||
@@ -403,41 +387,18 @@ class AscendMLAMetadataBuilder:
|
|||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
) -> AscendMLAMetadata:
|
) -> AscendMLAMetadata:
|
||||||
num_reqs = common_attn_metadata.num_reqs
|
num_reqs = common_attn_metadata.num_reqs
|
||||||
num_actual_tokens = common_attn_metadata.num_actual_tokens
|
|
||||||
query_start_loc = common_attn_metadata.query_start_loc
|
query_start_loc = common_attn_metadata.query_start_loc
|
||||||
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
|
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
|
||||||
long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
|
|
||||||
|
|
||||||
num_actual_tokens_pcp_padded = long_seq_metadata.num_actual_tokens_pcp_padded if long_seq_metadata else None
|
self.num_decodes, self.num_prefills, self.num_decode_tokens, self.num_prefill_tokens = \
|
||||||
|
|
||||||
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
|
|
||||||
split_decodes_and_prefills(common_attn_metadata, decode_threshold=self.decode_threshold)
|
split_decodes_and_prefills(common_attn_metadata, decode_threshold=self.decode_threshold)
|
||||||
assert num_decodes + num_prefills == num_reqs
|
self.set_num_actual_tokens(common_attn_metadata)
|
||||||
assert num_decode_tokens + num_prefill_tokens == num_actual_tokens
|
assert self.num_decodes + self.num_prefills == num_reqs
|
||||||
|
assert self.num_decode_tokens + self.num_prefill_tokens == common_attn_metadata.num_actual_tokens
|
||||||
# Note(simon): be careful about the CPU <> GPU memory movement in this
|
|
||||||
# function. We should avoid GPU -> CPU sync as much as possible because
|
|
||||||
# it blocks on all previous kernels.
|
|
||||||
device = self.device
|
|
||||||
|
|
||||||
# If graph_pad_size > -1, mean is running in fullgraph mode.
|
|
||||||
graph_pad_size = common_attn_metadata.graph_pad_size
|
|
||||||
# NOTE: Maybe this block_table change can be removed when graph_pad_size > 1.
|
|
||||||
if graph_pad_size > num_reqs and self.speculative_config.disable_padded_drafter_batch:
|
|
||||||
block_table = (
|
|
||||||
common_attn_metadata.block_table_tensor[:graph_pad_size])
|
|
||||||
else:
|
|
||||||
block_table = (common_attn_metadata.block_table_tensor[:num_reqs])
|
|
||||||
|
|
||||||
if num_actual_tokens_pcp_padded is None:
|
|
||||||
num_actual_tokens_pcp_padded = num_actual_tokens
|
|
||||||
|
|
||||||
# NOTE: Currently, MTP-fullgraph is incompatibility pcp
|
# NOTE: Currently, MTP-fullgraph is incompatibility pcp
|
||||||
slot_mapping = common_attn_metadata.slot_mapping[:
|
self.slot_mapping = common_attn_metadata.slot_mapping[:self.
|
||||||
num_actual_tokens_pcp_padded]
|
num_actual_tokens]
|
||||||
input_positions = common_attn_metadata.positions[:
|
|
||||||
num_actual_tokens_pcp_padded].long(
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.cos_cache is None:
|
if self.cos_cache is None:
|
||||||
self.cos_cache = model.model.layers[
|
self.cos_cache = model.model.layers[
|
||||||
@@ -451,59 +412,132 @@ class AscendMLAMetadataBuilder:
|
|||||||
self.model_config.dtype) # type: ignore
|
self.model_config.dtype) # type: ignore
|
||||||
|
|
||||||
query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
|
query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
|
||||||
query_lens = query_seq_lens_cpu[:num_reqs]
|
self.query_lens = query_seq_lens_cpu[:num_reqs]
|
||||||
seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
|
self.seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
|
||||||
num_computed_tokens_cpu = (seq_lens - query_lens)
|
|
||||||
|
self.set_prefill_block_table(common_attn_metadata)
|
||||||
|
|
||||||
prefill_metadata = None
|
prefill_metadata = None
|
||||||
chunked_context_metadata = None
|
if self.num_prefills > 0:
|
||||||
if num_prefills > 0:
|
prefill_metadata = self.build_prefill_metadata(
|
||||||
pcp_metadata = None
|
common_prefix_len, common_attn_metadata, model)
|
||||||
|
|
||||||
reqs_start = num_decodes # prefill_start
|
decode_metadata = None
|
||||||
tokens_start = num_decode_tokens
|
if self.num_decodes > 0:
|
||||||
max_query_len = query_lens[reqs_start:].max().item()
|
decode_metadata = self.build_decode_metadata(
|
||||||
max_seq_lens = seq_lens[reqs_start:].max().item()
|
common_prefix_len, common_attn_metadata, model)
|
||||||
|
|
||||||
|
return self.metadata_cls( # type: ignore
|
||||||
|
num_actual_tokens_pcp_padded=self.num_actual_tokens,
|
||||||
|
num_input_tokens=common_attn_metadata.num_input_tokens,
|
||||||
|
num_actual_tokens=self.num_actual_tokens,
|
||||||
|
query_lens=self.query_lens.tolist(),
|
||||||
|
slot_mapping=self.slot_mapping,
|
||||||
|
head_dim=self.model_config.get_head_size(),
|
||||||
|
num_decodes=self.num_decodes,
|
||||||
|
num_decode_tokens=self.num_decode_tokens,
|
||||||
|
num_prefills=self.num_prefills,
|
||||||
|
attn_mask=common_attn_metadata.attn_mask,
|
||||||
|
attn_state=common_attn_metadata.attn_state,
|
||||||
|
prefill=prefill_metadata,
|
||||||
|
decode=decode_metadata,
|
||||||
|
query_start_loc=query_start_loc,
|
||||||
|
block_tables=self.block_table,
|
||||||
|
seq_lens=self.seq_lens,
|
||||||
|
)
|
||||||
|
|
||||||
|
def build_chunked_metadata(
|
||||||
|
self,
|
||||||
|
common_prefix_len: int,
|
||||||
|
common_attn_metadata: AscendCommonAttentionMetadata,
|
||||||
|
model: nn.Module,
|
||||||
|
):
|
||||||
|
if not self.chunked_prefill_enabled:
|
||||||
|
return None
|
||||||
|
num_reqs = common_attn_metadata.num_reqs
|
||||||
|
|
||||||
|
num_computed_tokens_cpu = (self.seq_lens - self.query_lens)
|
||||||
|
reqs_start = self.num_decodes # prefill_start
|
||||||
|
|
||||||
|
self.context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs]
|
||||||
|
max_context_len_cpu = self.context_lens_cpu.max().item()
|
||||||
|
if not max_context_len_cpu > 0:
|
||||||
|
return None
|
||||||
|
num_prefills_with_context_cpu = (self.context_lens_cpu
|
||||||
|
> 0).sum().item()
|
||||||
|
self.max_context_chunk = (self.chunked_prefill_workspace_size //
|
||||||
|
num_prefills_with_context_cpu)
|
||||||
|
self.max_context_chunk = round_down(self.max_context_chunk,
|
||||||
|
self.block_size)
|
||||||
|
|
||||||
|
assert self.max_context_chunk > 0
|
||||||
|
self.num_chunks = cdiv(max_context_len_cpu, self.max_context_chunk)
|
||||||
|
chunk_starts = torch.arange(self.num_chunks, dtype=torch.int32) \
|
||||||
|
.unsqueeze(1).expand(-1, self.num_prefills) * self.max_context_chunk
|
||||||
|
chunk_ends = torch.min(self.context_lens_cpu.unsqueeze(0),
|
||||||
|
chunk_starts + self.max_context_chunk)
|
||||||
|
self.chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)
|
||||||
|
self.cu_seq_lens_cpu = torch.zeros(self.num_chunks,
|
||||||
|
self.num_prefills + 1,
|
||||||
|
dtype=torch.int32,
|
||||||
|
pin_memory=True)
|
||||||
|
torch.cumsum(self.chunk_seq_lens,
|
||||||
|
dim=1,
|
||||||
|
out=self.cu_seq_lens_cpu[:, 1:],
|
||||||
|
dtype=torch.int32)
|
||||||
|
return ChunkedContextMetadata(
|
||||||
|
cu_seq_lens=self.cu_seq_lens_cpu.pin_memory().to(
|
||||||
|
self.device, non_blocking=True),
|
||||||
|
starts=chunk_starts.pin_memory().to(self.device,
|
||||||
|
non_blocking=True),
|
||||||
|
seq_tot=self.chunk_seq_lens.sum(dim=1).tolist(),
|
||||||
|
max_seq_lens=self.chunk_seq_lens.max(dim=1).values.tolist(),
|
||||||
|
chunk_seq_lens=self.chunk_seq_lens,
|
||||||
|
chunk_seq_lens_npu=self.chunk_seq_lens.npu(),
|
||||||
|
workspace=self.chunked_prefill_workspace,
|
||||||
|
)
|
||||||
|
|
||||||
|
def set_prefill_block_table(
|
||||||
|
self,
|
||||||
|
common_attn_metadata: AscendCommonAttentionMetadata,
|
||||||
|
):
|
||||||
|
# If graph_pad_size > -1, mean is running in fullgraph mode.
|
||||||
|
self.graph_pad_size = common_attn_metadata.graph_pad_size
|
||||||
|
# NOTE: Maybe this block_table change can be removed when graph_pad_size > 1.
|
||||||
|
if self.graph_pad_size > common_attn_metadata.num_reqs and self.speculative_config.disable_padded_drafter_batch:
|
||||||
|
self.block_table = (
|
||||||
|
common_attn_metadata.block_table_tensor[:self.graph_pad_size])
|
||||||
|
else:
|
||||||
|
self.block_table = (
|
||||||
|
common_attn_metadata.block_table_tensor[:common_attn_metadata.
|
||||||
|
num_reqs])
|
||||||
|
|
||||||
|
def set_decode_block_table(
|
||||||
|
self, common_attn_metadata: AscendCommonAttentionMetadata):
|
||||||
|
self.block_table = self.block_table[:self.num_decodes, ...]
|
||||||
|
|
||||||
|
def build_prefill_metadata(
|
||||||
|
self,
|
||||||
|
common_prefix_len: int,
|
||||||
|
common_attn_metadata: AscendCommonAttentionMetadata,
|
||||||
|
model: nn.Module,
|
||||||
|
) -> AscendMLAPrefillMetadata:
|
||||||
|
query_start_loc = common_attn_metadata.query_start_loc
|
||||||
|
|
||||||
|
# NOTE: Currently, MTP-fullgraph is incompatibility pcp
|
||||||
|
input_positions = common_attn_metadata.positions[:self.
|
||||||
|
num_actual_tokens].long(
|
||||||
|
)
|
||||||
|
|
||||||
|
chunked_context_metadata = self.build_chunked_metadata(
|
||||||
|
common_prefix_len, common_attn_metadata, model)
|
||||||
|
reqs_start = self.num_decodes # prefill_start
|
||||||
|
tokens_start = self.num_decode_tokens
|
||||||
|
max_query_len = self.query_lens[reqs_start:].max().item()
|
||||||
|
max_seq_lens = self.seq_lens[reqs_start:].max().item()
|
||||||
prefill_query_start_loc = query_start_loc[
|
prefill_query_start_loc = query_start_loc[
|
||||||
reqs_start:] - query_start_loc[reqs_start]
|
reqs_start:] - query_start_loc[reqs_start]
|
||||||
|
|
||||||
context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs]
|
|
||||||
max_context_len_cpu = context_lens_cpu.max().item()
|
|
||||||
num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item()
|
|
||||||
if self.chunked_prefill_enabled and max_context_len_cpu > 0:
|
|
||||||
max_context_chunk = (self.chunked_prefill_workspace_size //
|
|
||||||
num_prefills_with_context_cpu)
|
|
||||||
max_context_chunk = round_down(max_context_chunk,
|
|
||||||
self.block_size)
|
|
||||||
|
|
||||||
assert max_context_chunk > 0
|
|
||||||
num_chunks = cdiv(max_context_len_cpu, max_context_chunk)
|
|
||||||
chunk_starts = torch.arange(num_chunks, dtype=torch.int32) \
|
|
||||||
.unsqueeze(1).expand(-1, num_prefills) * max_context_chunk
|
|
||||||
chunk_ends = torch.min(context_lens_cpu.unsqueeze(0),
|
|
||||||
chunk_starts + max_context_chunk)
|
|
||||||
chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)
|
|
||||||
cu_seq_lens_cpu = torch.zeros(num_chunks,
|
|
||||||
num_prefills + 1,
|
|
||||||
dtype=torch.int32,
|
|
||||||
pin_memory=True)
|
|
||||||
torch.cumsum(chunk_seq_lens,
|
|
||||||
dim=1,
|
|
||||||
out=cu_seq_lens_cpu[:, 1:],
|
|
||||||
dtype=torch.int32)
|
|
||||||
|
|
||||||
chunked_context_metadata = (
|
|
||||||
AscendMLAPrefillMetadata.ChunkedContextMetadata(
|
|
||||||
cu_seq_lens=cu_seq_lens_cpu.pin_memory().to(
|
|
||||||
device, non_blocking=True),
|
|
||||||
starts=chunk_starts.pin_memory().to(device,
|
|
||||||
non_blocking=True),
|
|
||||||
seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
|
|
||||||
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
|
|
||||||
chunk_seq_lens=chunk_seq_lens,
|
|
||||||
chunk_seq_lens_npu=chunk_seq_lens.npu(),
|
|
||||||
workspace=self.chunked_prefill_workspace,
|
|
||||||
))
|
|
||||||
prefill_input_positions = input_positions[tokens_start:]
|
prefill_input_positions = input_positions[tokens_start:]
|
||||||
cos = self.cos_cache[
|
cos = self.cos_cache[
|
||||||
prefill_input_positions].unsqueeze( # type: ignore
|
prefill_input_positions].unsqueeze( # type: ignore
|
||||||
@@ -511,79 +545,93 @@ class AscendMLAMetadataBuilder:
|
|||||||
sin = self.sin_cache[
|
sin = self.sin_cache[
|
||||||
prefill_input_positions].unsqueeze( # type: ignore
|
prefill_input_positions].unsqueeze( # type: ignore
|
||||||
1).unsqueeze(2)
|
1).unsqueeze(2)
|
||||||
prefill_metadata = AscendMLAPrefillMetadata(
|
return AscendMLAPrefillMetadata(
|
||||||
attn_mask=common_attn_metadata.attn_mask,
|
attn_mask=common_attn_metadata.attn_mask,
|
||||||
query_lens=query_lens[reqs_start:].to(torch.int32),
|
query_lens=self.query_lens[reqs_start:].to(torch.int32),
|
||||||
seq_lens=seq_lens,
|
seq_lens=self.seq_lens,
|
||||||
context_lens=seq_lens[reqs_start:],
|
context_lens=self.seq_lens[reqs_start:],
|
||||||
input_positions=prefill_input_positions,
|
input_positions=prefill_input_positions,
|
||||||
block_table=block_table[reqs_start:, ...],
|
block_table=self.block_table[reqs_start:, ...],
|
||||||
max_query_len=max_query_len,
|
max_query_len=max_query_len,
|
||||||
max_seq_lens=max_seq_lens,
|
max_seq_lens=max_seq_lens,
|
||||||
query_start_loc=prefill_query_start_loc,
|
query_start_loc=prefill_query_start_loc,
|
||||||
chunked_context=chunked_context_metadata,
|
chunked_context=chunked_context_metadata,
|
||||||
sin=sin,
|
sin=sin,
|
||||||
cos=cos,
|
cos=cos,
|
||||||
pcp_metadata=pcp_metadata,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
decode_metadata = None
|
def build_decode_metadata(
|
||||||
if num_decodes > 0:
|
self,
|
||||||
|
common_prefix_len: int,
|
||||||
|
common_attn_metadata: AscendCommonAttentionMetadata,
|
||||||
|
model: nn.Module,
|
||||||
|
) -> AscendMLADecodeMetadata:
|
||||||
|
num_reqs = common_attn_metadata.num_reqs
|
||||||
|
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
|
||||||
|
|
||||||
|
input_positions = common_attn_metadata.positions[:self.
|
||||||
|
num_actual_tokens].long(
|
||||||
|
)
|
||||||
|
|
||||||
cos, sin = get_cos_and_sin_mla()
|
cos, sin = get_cos_and_sin_mla()
|
||||||
# Notice that num_decodes != num_decode_tokens in SpecDecoding Scenario
|
# Notice that num_decodes != num_decode_tokens in SpecDecoding Scenario
|
||||||
actual_seq_lengths_q = query_start_loc_cpu[1:num_decodes +
|
actual_seq_lengths_q = query_start_loc_cpu[1:self.num_decodes +
|
||||||
1].tolist()
|
1].tolist()
|
||||||
max_seq_lens = seq_lens[:num_decodes].max().item()
|
max_seq_lens = self.seq_lens[:self.num_decodes].max().item()
|
||||||
seq_lens = seq_lens[:num_decodes]
|
self.seq_lens = self.seq_lens[:self.num_decodes]
|
||||||
input_positions = input_positions[:num_decode_tokens]
|
input_positions = input_positions[:self.num_decode_tokens]
|
||||||
block_table = block_table[:num_decodes, ...]
|
|
||||||
|
self.set_decode_block_table(common_attn_metadata)
|
||||||
|
|
||||||
# NOTE: Currently, MTP-fullgraph is incompatibility pcp
|
# NOTE: Currently, MTP-fullgraph is incompatibility pcp
|
||||||
# NOTE: Maybe this block_table change can be removed when graph_pad_size > 1.
|
# NOTE: Maybe this block_table change can be removed when graph_pad_size > 1.
|
||||||
if graph_pad_size > num_decodes and \
|
if self.graph_pad_size > self.num_decodes and \
|
||||||
self.speculative_config.disable_padded_drafter_batch:
|
self.speculative_config.disable_padded_drafter_batch:
|
||||||
block_table = block_table[:graph_pad_size, ...]
|
self.block_table = self.block_table[:self.graph_pad_size, ...]
|
||||||
seq_lens_list = seq_lens.tolist()
|
seq_lens_list = self.seq_lens.tolist()
|
||||||
|
|
||||||
cp_seq_len, batch_seq_mask = None, None
|
cp_seq_len, batch_seq_mask = None, None
|
||||||
|
|
||||||
if graph_pad_size > num_reqs:
|
if self.graph_pad_size > num_reqs:
|
||||||
if self.speculative_config.disable_padded_drafter_batch:
|
if self.speculative_config.disable_padded_drafter_batch:
|
||||||
num_reqs_pad_size = graph_pad_size - num_reqs
|
num_reqs_pad_size = self.graph_pad_size - num_reqs
|
||||||
actual_seq_lengths_q = self.pad_actual_seq_len_q_mtp_disable_pad(
|
actual_seq_lengths_q = self.pad_actual_seq_len_q_mtp_disable_pad(
|
||||||
num_reqs_pad_size, num_reqs, actual_seq_lengths_q)
|
num_reqs_pad_size, num_reqs, actual_seq_lengths_q)
|
||||||
seq_lens_list = seq_lens_list + [0] * (graph_pad_size -
|
seq_lens_list = seq_lens_list + [0] * (self.graph_pad_size -
|
||||||
num_decodes)
|
self.num_decodes)
|
||||||
num_block_pad_size = graph_pad_size - block_table.shape[0]
|
num_block_pad_size = self.graph_pad_size - self.block_table.shape[
|
||||||
|
0]
|
||||||
if num_block_pad_size > 0:
|
if num_block_pad_size > 0:
|
||||||
block_table_padding = torch.zeros(
|
block_table_padding = torch.zeros(
|
||||||
(num_block_pad_size, ) + block_table.shape[1:],
|
(num_block_pad_size, ) + self.block_table.shape[1:],
|
||||||
dtype=block_table.dtype,
|
dtype=self.block_table.dtype,
|
||||||
device=block_table.device)
|
device=self.block_table.device)
|
||||||
block_table = torch.cat(
|
self.block_table = torch.cat(
|
||||||
[block_table, block_table_padding], dim=0)
|
[self.block_table, block_table_padding], dim=0)
|
||||||
else:
|
else:
|
||||||
num_token_pad_size = graph_pad_size - num_decode_tokens
|
num_token_pad_size = self.graph_pad_size - self.num_decode_tokens
|
||||||
num_reqs_pad_size = (
|
num_reqs_pad_size = (
|
||||||
graph_pad_size //
|
self.graph_pad_size //
|
||||||
common_attn_metadata.decode_token_per_req - num_reqs)
|
common_attn_metadata.decode_token_per_req - num_reqs)
|
||||||
num_block_table_pad_size = (
|
num_block_table_pad_size = (
|
||||||
graph_pad_size //
|
self.graph_pad_size //
|
||||||
common_attn_metadata.decode_token_per_req -
|
common_attn_metadata.decode_token_per_req -
|
||||||
num_decodes)
|
self.num_decodes)
|
||||||
seq_lens_list = seq_lens.tolist() + [0] * num_reqs_pad_size
|
seq_lens_list = self.seq_lens.tolist() + [0
|
||||||
|
] * num_reqs_pad_size
|
||||||
slot_padding = torch.full((num_token_pad_size, ),
|
slot_padding = torch.full((num_token_pad_size, ),
|
||||||
PAD_SLOT_ID,
|
PAD_SLOT_ID,
|
||||||
dtype=slot_mapping.dtype,
|
dtype=self.slot_mapping.dtype,
|
||||||
device=slot_mapping.device)
|
device=self.slot_mapping.device)
|
||||||
slot_mapping = torch.cat([slot_mapping, slot_padding])
|
self.slot_mapping = torch.cat(
|
||||||
|
[self.slot_mapping, slot_padding])
|
||||||
block_table_padding = torch.zeros(
|
block_table_padding = torch.zeros(
|
||||||
(num_block_table_pad_size, ) + block_table.shape[1:],
|
(num_block_table_pad_size, ) + self.block_table.shape[1:],
|
||||||
dtype=block_table.dtype,
|
dtype=self.block_table.dtype,
|
||||||
device=block_table.device)
|
device=self.block_table.device)
|
||||||
block_table = torch.cat([block_table, block_table_padding],
|
self.block_table = torch.cat(
|
||||||
dim=0)
|
[self.block_table, block_table_padding], dim=0)
|
||||||
position_padding = torch.zeros(
|
position_padding = torch.zeros(num_token_pad_size,
|
||||||
num_token_pad_size,
|
|
||||||
dtype=input_positions.dtype,
|
dtype=input_positions.dtype,
|
||||||
device=input_positions.device)
|
device=input_positions.device)
|
||||||
input_positions = torch.cat(
|
input_positions = torch.cat(
|
||||||
@@ -596,17 +644,15 @@ class AscendMLAMetadataBuilder:
|
|||||||
assert self.cos_cache is not None
|
assert self.cos_cache is not None
|
||||||
assert self.sin_cache is not None
|
assert self.sin_cache is not None
|
||||||
if cos is None and sin is None:
|
if cos is None and sin is None:
|
||||||
cos = self.cos_cache[
|
cos = self.cos_cache[input_positions].unsqueeze( # type: ignore
|
||||||
input_positions].unsqueeze( # type: ignore
|
|
||||||
1).unsqueeze(2)
|
1).unsqueeze(2)
|
||||||
sin = self.sin_cache[
|
sin = self.sin_cache[input_positions].unsqueeze( # type: ignore
|
||||||
input_positions].unsqueeze( # type: ignore
|
|
||||||
1).unsqueeze(2)
|
1).unsqueeze(2)
|
||||||
|
|
||||||
decode_metadata = AscendMLADecodeMetadata(
|
decode_metadata = AscendMLADecodeMetadata(
|
||||||
input_positions=input_positions,
|
input_positions=input_positions,
|
||||||
block_table=block_table,
|
block_table=self.block_table,
|
||||||
seq_lens=seq_lens,
|
seq_lens=self.seq_lens,
|
||||||
seq_lens_list=seq_lens_list,
|
seq_lens_list=seq_lens_list,
|
||||||
max_seq_lens=max_seq_lens,
|
max_seq_lens=max_seq_lens,
|
||||||
attn_mask=common_attn_metadata.spec_attn_mask,
|
attn_mask=common_attn_metadata.spec_attn_mask,
|
||||||
@@ -616,44 +662,26 @@ class AscendMLAMetadataBuilder:
|
|||||||
cp_seq_len=cp_seq_len,
|
cp_seq_len=cp_seq_len,
|
||||||
batch_seq_mask=batch_seq_mask)
|
batch_seq_mask=batch_seq_mask)
|
||||||
else:
|
else:
|
||||||
cos[:num_decode_tokens,
|
cos[:self.num_decode_tokens,
|
||||||
...] = self.cos_cache[input_positions].unsqueeze(
|
...] = self.cos_cache[input_positions].unsqueeze(1).unsqueeze(
|
||||||
1).unsqueeze(2)
|
2)
|
||||||
sin[:num_decode_tokens,
|
sin[:self.num_decode_tokens,
|
||||||
...] = self.sin_cache[input_positions].unsqueeze(
|
...] = self.sin_cache[input_positions].unsqueeze(1).unsqueeze(
|
||||||
1).unsqueeze(2)
|
2)
|
||||||
|
|
||||||
decode_metadata = AscendMLADecodeMetadata(
|
decode_metadata = AscendMLADecodeMetadata(
|
||||||
input_positions=input_positions,
|
input_positions=input_positions,
|
||||||
block_table=block_table,
|
block_table=self.block_table,
|
||||||
seq_lens=seq_lens,
|
seq_lens=self.seq_lens,
|
||||||
seq_lens_list=seq_lens_list,
|
seq_lens_list=seq_lens_list,
|
||||||
max_seq_lens=max_seq_lens,
|
max_seq_lens=max_seq_lens,
|
||||||
attn_mask=common_attn_metadata.spec_attn_mask,
|
attn_mask=common_attn_metadata.spec_attn_mask,
|
||||||
actual_seq_lengths_q=actual_seq_lengths_q,
|
actual_seq_lengths_q=actual_seq_lengths_q,
|
||||||
sin=sin[:num_decode_tokens, ...],
|
sin=sin[:self.num_decode_tokens, ...],
|
||||||
cos=cos[:num_decode_tokens, ...],
|
cos=cos[:self.num_decode_tokens, ...],
|
||||||
cp_seq_len=cp_seq_len,
|
cp_seq_len=cp_seq_len,
|
||||||
batch_seq_mask=batch_seq_mask)
|
batch_seq_mask=batch_seq_mask)
|
||||||
|
return decode_metadata
|
||||||
return self.metadata_cls( # type: ignore
|
|
||||||
num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded,
|
|
||||||
num_input_tokens=common_attn_metadata.num_input_tokens,
|
|
||||||
num_actual_tokens=num_actual_tokens,
|
|
||||||
query_lens=query_lens.tolist(),
|
|
||||||
slot_mapping=slot_mapping,
|
|
||||||
head_dim=self.model_config.get_head_size(),
|
|
||||||
num_decodes=num_decodes,
|
|
||||||
num_decode_tokens=num_decode_tokens,
|
|
||||||
num_prefills=num_prefills,
|
|
||||||
attn_mask=common_attn_metadata.attn_mask,
|
|
||||||
attn_state=common_attn_metadata.attn_state,
|
|
||||||
prefill=prefill_metadata,
|
|
||||||
decode=decode_metadata,
|
|
||||||
query_start_loc=query_start_loc,
|
|
||||||
block_tables=block_table,
|
|
||||||
seq_lens=seq_lens,
|
|
||||||
)
|
|
||||||
|
|
||||||
def build_for_graph_capture(
|
def build_for_graph_capture(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -87,7 +87,7 @@ class AscendPrefillContextParallelMetadata:
|
|||||||
|
|
||||||
cp_kv_recover_idx_for_chunk: torch.Tensor = None
|
cp_kv_recover_idx_for_chunk: torch.Tensor = None
|
||||||
|
|
||||||
num_actual_tokens_pcp_padded: Optional[int] = None
|
num_actual_tokens_pcp_padded: int = 0
|
||||||
|
|
||||||
num_computed_tokens_of_pcp_dcp: Optional[list[list[list[int]]]] = None
|
num_computed_tokens_of_pcp_dcp: Optional[list[list[list[int]]]] = None
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user