### What this PR does / why we need it?

This PR adds DCP support to the SFA backend.

Please note that due to operator constraints, the current implementation has to all-gather the entire KV cache and modify the block table to satisfy the operator's input requirements. This results in significantly increased communication overhead and peak memory usage, so it is only a temporary workaround and will be refactored once the operator provides proper support.

Additionally, because of the above limitations, `cp_kv_cache_interleave_size` is currently required to be equal to `block_size`. This restriction will also be removed after the refactor.

#### Test

Accuracy test using DeepSeek-V3.2-Exp-W8A8 with dp2tp8dcp8:

| dataset | version | metric | mode | vllm-api-general-stream |
| ----- | ----- | ----- | ----- | ----- |
| gsm8kdataset | - | accuracy | gen | 96.35 |

- vLLM version: v0.15.0
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.15.0

---------

Signed-off-by: QiuChunshuo <qiuchunshuo@huawei.com>
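For context, here is a minimal sketch of the workaround described above, assuming a `torch.distributed` process group for the DCP ranks. This is not the code added by this PR: `gather_full_kv_for_sfa` and `dcp_group` are hypothetical names used only to illustrate the all-gather and the interleave-size restriction; the actual implementation lives in the SFA backend (`vllm_ascend/attention/sfa_v1.py`).

```python
# Hypothetical sketch only; names and shapes are illustrative, not this PR's API.
import torch
import torch.distributed as dist


def gather_full_kv_for_sfa(local_kv_cache: torch.Tensor,
                           dcp_group: dist.ProcessGroup,
                           block_size: int,
                           cp_kv_cache_interleave_size: int) -> torch.Tensor:
    # Temporary restriction from the PR description: the interleave size
    # must equal the block size until the operator supports sharded KV.
    if cp_kv_cache_interleave_size != block_size:
        raise ValueError(
            "SFA with DCP currently requires "
            "cp_kv_cache_interleave_size == block_size")

    # Every DCP rank contributes its KV-cache shard; the concatenated
    # result is what gets fed to the SFA operator. This all-gather is the
    # source of the extra communication and peak-memory cost noted above.
    world_size = dist.get_world_size(group=dcp_group)
    gathered = [torch.empty_like(local_kv_cache) for _ in range(world_size)]
    dist.all_gather(gathered, local_kv_cache, group=dcp_group)
    full_kv_cache = torch.cat(gathered, dim=0)

    # The block table must then be rewritten so block ids index into
    # `full_kv_cache` instead of the local shard (omitted here; it depends
    # on how blocks are interleaved across DCP ranks).
    return full_kv_cache
```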
import sys
from unittest.mock import MagicMock, patch

import torch

from tests.ut.attention.utils import patch_distributed_groups
from tests.ut.base import TestBase
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm.distributed.parallel_state import GroupCoordinator

if 'torch_npu._inductor' not in sys.modules:
    sys.modules['torch_npu._inductor'] = MagicMock()

from vllm_ascend.attention.sfa_v1 import (AscendSFABackend, AscendSFAImpl,
                                          AscendSFAMetadata,
                                          AscendSFAMetadataBuilder)
from vllm_ascend.utils import enable_dsa_cp, vllm_version_is


class TestAscendSFABackend(TestBase):

    def test_get_name(self):
        self.assertEqual(AscendSFABackend.get_name(), "ASCEND_SFA")

    def test_get_builder_cls(self):
        self.assertEqual(AscendSFABackend.get_builder_cls(),
                         AscendSFAMetadataBuilder)

    def test_get_kv_cache_shape(self):
        result = AscendSFABackend.get_kv_cache_shape(2, 4, 8, 128)
        self.assertEqual(result, (2, 4, 8, 128))

    def test_get_impl_cls(self):
        result = AscendSFABackend.get_impl_cls()
        self.assertEqual(result, AscendSFAImpl)


class TestAscendSFAMetadata(TestBase):

    def test_ascend_sfa_metadata_default(self):
        num_actual_tokens = 100
        slot_mapping = torch.randn(100, 4, 1024)
        seq_lens = torch.tensor([30, 50])
        cum_query_lens = torch.tensor([0, 30, 80])
        block_table = torch.randint(0, 100, (100, 4))

        rope_dim = 32
        max_seq_len = int(seq_lens.max().item())
        sin = torch.randn(max_seq_len, rope_dim)
        cos = torch.randn(max_seq_len, rope_dim)

        num_input_tokens = 2
        head_dim = None
        attn_mask = None
        attn_state = AscendAttentionState.ChunkedPrefill

        metadata = AscendSFAMetadata(
            num_actual_tokens=num_actual_tokens,
            slot_mapping=slot_mapping,
            seq_lens=seq_lens,
            cum_query_lens=cum_query_lens,
            block_table=block_table,
            sin=sin,
            cos=cos,
            num_input_tokens=num_input_tokens,
            head_dim=head_dim,
            attn_mask=attn_mask,
            attn_state=attn_state,
        )

        self.assertEqual(metadata.num_actual_tokens, num_actual_tokens)
        self.assertIs(metadata.slot_mapping, slot_mapping)
        self.assertTrue(torch.equal(metadata.seq_lens, seq_lens))
        self.assertTrue(torch.equal(metadata.cum_query_lens, cum_query_lens))
        self.assertIs(metadata.block_table, block_table)
        self.assertIs(metadata.sin, sin)
        self.assertIs(metadata.cos, cos)
        self.assertEqual(metadata.num_input_tokens, num_input_tokens)
        self.assertIs(metadata.head_dim, head_dim)
        self.assertIs(metadata.attn_mask, attn_mask)
        self.assertEqual(metadata.attn_state, attn_state)


class TestAscendSFAMetadataBuilder(TestBase):

    @patch('vllm.distributed.parallel_state._TP',
           new_callable=lambda: MagicMock(spec=GroupCoordinator))
    def setUp(self, mock_tp):
        mock_tp.world_size = 2
        mock_tp.rank_in_group = MagicMock()
        mock_tp.device_group = MagicMock()

        self.mock_cfg = MagicMock()

        self.mock_cfg.parallel_config = MagicMock()
        self.mock_cfg.parallel_config.tensor_parallel_size = 1
        self.mock_cfg.parallel_config.prefill_context_parallel_size = 1
        self.mock_cfg.parallel_config.decode_context_parallel_size = 1

        self.mock_cfg.compilation_config = MagicMock()
        self.mock_cfg.compilation_config.pass_config = MagicMock()
        self.mock_cfg.compilation_config.pass_config.enable_sp = False

        self.mock_cfg.speculative_config.num_speculative_tokens = 0

        self.patcher = patch("vllm.config.get_current_vllm_config",
                             return_value=self.mock_cfg)
        self.patcher.start()

        # Mock parent class __init__ to avoid complex initialization,
        # but still set the essential attributes that child class needs
        def mock_parent_init(self, kv_cache_spec, layer_names, vllm_config,
                             device, metadata_cls, supports_dcp_with_varlen):
            self.metadata_cls = metadata_cls
            self.kv_cache_spec = kv_cache_spec
            self.model_config = vllm_config.model_config
            self.vllm_config = vllm_config
            self.device = device
            self.chunked_prefill_workspace_size = 128 * 1024
            self.chunked_prefill_workspace = torch.empty(
                (self.chunked_prefill_workspace_size,
                 vllm_config.model_config.get_head_size()),
                dtype=vllm_config.model_config.dtype,
                device=device,
            )

        self.parent_init_patcher = patch(
            "vllm.model_executor.layers.attention.mla_attention.MLACommonMetadataBuilder.__init__",
            mock_parent_init)
        self.parent_init_patcher.start()

        if hasattr(enable_dsa_cp, "cache_clear"):
            enable_dsa_cp.cache_clear()

    def tearDown(self):
        self.patcher.stop()
        self.parent_init_patcher.stop()

    @patch_distributed_groups(dcp_size=2, pcp_size=2, needs_mocks=False)
    def test_ascend_sfa_metadata_builder_default(self):
        kv_cache_spec = MagicMock()
        layer_names = ["layer1", "layer2"]
        vllm_config = MagicMock()
        vllm_config.cache_config.block_size = 16
        vllm_config.model_config.max_model_len = 1024
        vllm_config.model_config.get_head_size.return_value = 64
        vllm_config.model_config.dtype = torch.float16
        vllm_config.model_config.hf_text_config.qk_rope_head_dim = 64
        speculative_config = MagicMock()
        speculative_config.num_speculative_tokens = 4
        vllm_config.speculative_config = speculative_config
        device = torch.device("cpu")

        builder = AscendSFAMetadataBuilder(kv_cache_spec=kv_cache_spec,
                                           layer_names=layer_names,
                                           vllm_config=vllm_config,
                                           device=device)

        assert builder.device == device
        assert builder.vllm_config == vllm_config

    @patch("vllm_ascend.attention.sfa_v1.get_current_vllm_config")
    @patch("vllm_ascend.attention.sfa_v1.get_cos_and_sin_mla")
    @patch("vllm_ascend.attention.sfa_v1.enable_dsa_cp")
    @patch_distributed_groups(dcp_size=2, pcp_size=2, needs_mocks=False)
    def test_ascend_sfa_metadata_builder_build(
        self,
        mock_enable_dsa_cp,
        mock_get_cos_and_sin_mla,
        mock_get_current_vllm_config,
    ):
        mock_enable_dsa_cp.return_value = False

        cfg = MagicMock()
        cfg.model_config = MagicMock()
        cfg.model_config.hf_text_config = MagicMock()

        mock_get_current_vllm_config.return_value = cfg
        kv_cache_spec = MagicMock()
        layer_names = ["layer1", "layer2"]
        vllm_config = MagicMock()
        vllm_config.cache_config.block_size = 16
        vllm_config.model_config.max_model_len = 1024
        vllm_config.model_config.get_head_size.return_value = 64
        vllm_config.model_config.dtype = torch.float16
        vllm_config.model_config.hf_text_config.qk_rope_head_dim = 64
        speculative_config = MagicMock()
        speculative_config.num_speculative_tokens = 4
        vllm_config.speculative_config = speculative_config
        device = torch.device("cpu")

        builder = AscendSFAMetadataBuilder(kv_cache_spec=kv_cache_spec,
                                           layer_names=layer_names,
                                           vllm_config=vllm_config,
                                           device=device)

        common_attn_metadata = MagicMock()
        common_attn_metadata.num_reqs = 10
        common_attn_metadata.num_actual_tokens = 100
        common_attn_metadata.query_start_loc = torch.tensor(
            [0, 10, 20, 30, 40, 50, 60, 70, 80, 90])
        common_attn_metadata.query_start_loc_cpu = torch.tensor(
            [0, 10, 20, 30, 40, 50, 60, 70, 80, 90])
        common_attn_metadata.slot_mapping = torch.randn(100, 4, 1024)
        common_attn_metadata.seq_lens_cpu = torch.tensor([2] * 10)
        common_attn_metadata.positions = torch.randn(100)
        common_attn_metadata.attn_mask = None
        common_attn_metadata.attn_state = AscendAttentionState.ChunkedPrefill
        common_attn_metadata.block_table_tensor = torch.randn(100, 4)
        common_attn_metadata.cos = None
        common_attn_metadata.sin = None
        common_attn_metadata.num_input_tokens = 100

        mock_get_cos_and_sin_mla.return_value = (torch.randn(100),
                                                 torch.randn(100))

        metadata = builder.build(
            common_prefix_len=10,
            common_attn_metadata=common_attn_metadata,
        )

        assert isinstance(metadata, AscendSFAMetadata)
        assert metadata.num_actual_tokens == common_attn_metadata.num_actual_tokens
        assert metadata.slot_mapping.shape == (100, 4, 1024)

    @patch("vllm_ascend.attention.sfa_v1.get_current_vllm_config")
    @patch("vllm_ascend.attention.sfa_v1.get_cos_and_sin_mla")
    @patch_distributed_groups(dcp_size=2, pcp_size=2, needs_mocks=False)
    def test_ascend_sfa_metadata_builder_build_for_graph_capture(
            self, mock_get_cos_and_sin_mla, mock_get_current_vllm_config):
        cfg = MagicMock()
        cfg.model_config = MagicMock()
        cfg.model_config.hf_text_config = MagicMock()

        mock_get_current_vllm_config.return_value = cfg

        kv_cache_spec = MagicMock()
        layer_names = ["layer1", "layer2"]
        vllm_config = MagicMock()
        vllm_config.cache_config.block_size = 16
        vllm_config.model_config.max_model_len = 1024
        vllm_config.model_config.get_head_size.return_value = 64
        vllm_config.model_config.dtype = torch.float16
        vllm_config.model_config.hf_text_config.qk_rope_head_dim = 64
        speculative_config = MagicMock()
        speculative_config.num_speculative_tokens = 4
        vllm_config.speculative_config = speculative_config
        device = torch.device("cpu")

        builder = AscendSFAMetadataBuilder(kv_cache_spec=kv_cache_spec,
                                           layer_names=layer_names,
                                           vllm_config=vllm_config,
                                           device=device)

        common_attn_metadata = MagicMock()
        common_attn_metadata.num_reqs = 10
        common_attn_metadata.num_actual_tokens = 100
        common_attn_metadata.query_start_loc = torch.tensor(
            [0, 10, 20, 30, 40, 50, 60, 70, 80, 90])
        common_attn_metadata.query_start_loc_cpu = torch.tensor(
            [0, 10, 20, 30, 40, 50, 60, 70, 80, 90])
        common_attn_metadata.slot_mapping = torch.randn(100, 4, 1024)
        common_attn_metadata.seq_lens_cpu = torch.tensor([2] * 10)
        common_attn_metadata.positions = torch.randn(100)
        common_attn_metadata.attn_mask = None
        common_attn_metadata.attn_state = AscendAttentionState.ChunkedPrefill
        common_attn_metadata.block_table_tensor = torch.randn(100, 4)
        common_attn_metadata.cos = None
        common_attn_metadata.sin = None
        common_attn_metadata.num_input_tokens = 100

        mock_get_cos_and_sin_mla.return_value = (torch.randn(100),
                                                 torch.randn(100))

        attn_metadata = builder.build_for_graph_capture(
            common_attn_metadata=common_attn_metadata,
            attn_state=AscendAttentionState.DecodeOnly,
        )

        assert isinstance(attn_metadata, AscendSFAMetadata)
        assert attn_metadata.attn_state == AscendAttentionState.DecodeOnly