[Feat](sfa,dcp) support dcp for sfa (#6563)

### What this PR does / why we need it?
This PR adds decode context parallel (DCP) support to the sparse flash attention (SFA) backend.

Please note that, due to operator constraints, the current implementation
has to all-gather the entire KV cache across the context-parallel group
and remap the block table to satisfy the operator's input requirements.
This significantly increases communication overhead and peak memory
usage, so this is only a temporary workaround; it will be refactored once
the operator provides proper support. A sketch of the block-table
remapping is shown below.
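
To make the workaround concrete, here is a minimal standalone sketch of the dedup-and-offset block-table remapping the metadata builder performs (see `AscendSFAMetadataBuilder` in the diff below). Shapes and values are toy; only the `unique(return_inverse=True)` call and the per-rank offset mirror the actual code.

```python
import torch

cp_size = 2  # pcp_size * dcp_size
# Per-request physical KV block IDs (toy values).
block_table = torch.tensor([[7, 3, 7, 0],
                            [3, 9, 0, 0]])

# Deduplicate the referenced blocks and reindex the table into the compact
# gathered cache; valid_block_ids selects the blocks each rank must gather.
valid_block_ids, new_block_table = block_table.flatten().unique(return_inverse=True)
num_blocks = valid_block_ids.shape[0]

# After the all-gather, rank r's copy of compact block b sits at
# r * num_blocks + b, so each request's row is widened with per-rank offsets.
block_table_cp = (
    new_block_table.unsqueeze(-1)
    + (torch.arange(cp_size) * num_blocks).view(1, 1, -1)
).reshape(block_table.shape[0], -1)

print(valid_block_ids)       # tensor([0, 3, 7, 9])
print(block_table_cp.shape)  # torch.Size([2, 8]): rows index both ranks' copies
```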

Additionally, because of the above limitations,
`cp_kv_cache_interleave_size` is currently required to be equal to
`block_size`, so that each KV cache block is owned entirely by a single
CP rank. This restriction will also be removed after the refactor.
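
For intuition on why this restriction exists: `cp_kv_cache_interleave_size` controls how many consecutive tokens are stored on one CP rank before switching to the next. The helper below is a hypothetical illustration (not a vLLM API); when the interleave size equals the block size, every KV block is owned by exactly one rank, which is what allows the whole-block remapping sketched above.

```python
def owning_rank(token_idx: int, interleave_size: int, cp_size: int) -> int:
    """Hypothetical helper: which CP rank stores this token's KV entry."""
    return (token_idx // interleave_size) % cp_size

block_size, cp_size = 4, 2
print([owning_rank(t, interleave_size=block_size, cp_size=cp_size) for t in range(8)])
# [0, 0, 0, 0, 1, 1, 1, 1] -> tokens of one block never straddle ranks
```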

#### Test
Accuracy test using DeepSeek-V3.2-Exp-W8A8 with dp2tp8dcp8:

| dataset | version | metric | mode | vllm-api-general-stream |
| ----- | ----- | ----- | ----- | ----- |
| gsm8kdataset | - | accuracy | gen | 96.35 |

- vLLM version: v0.15.0
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.15.0

---------

Signed-off-by: QiuChunshuo <qiuchunshuo@huawei.com>
Author: Qiu
Committed by: GitHub
Date: 2026-02-09 18:52:25 +08:00
Parent: 80e5812b39
Commit: cb7c419bc0
5 changed files with 190 additions and 13 deletions


@@ -77,6 +77,9 @@ jobs:
       - name: multi-node-qwenw8a8-2node-longseq
         config_file_path: Qwen3-235B-W8A8-longseq.yaml
         size: 2
+      - name: multi-node-deepseek-V3_2-W8A8-cp
+        config_file_path: DeepSeek-V3_2-W8A8-cp.yaml
+        size: 2
       - name: multi-node-qwen-disagg-pd
         config_file_path: Qwen3-235B-disagg-pd.yaml
         size: 2


@@ -0,0 +1,91 @@
test_name: "test DeepSeek-V3.2-W8A8 for PCP&DCP"
model: "vllm-ascend/DeepSeek-V3.2-W8A8"
num_nodes: 2
npu_per_node: 16
env_common:
  HCCL_OP_EXPANSION_MODE: "AIV"
  VLLM_USE_MODELSCOPE: true
  HCCL_BUFFSIZE: 1024
  SERVER_PORT: 8080
  OMP_PROC_BIND: false
  OMP_NUM_THREADS: 1
  PYTORCH_NPU_ALLOC_CONF: "expandable_segments:True"
  VLLM_ASCEND_ENABLE_FLASHCOMM1: 1
  ASCEND_A3_EBA_ENABLE: 1
deployment:
  -
    server_cmd: >
      vllm serve vllm-ascend/DeepSeek-V3.2-W8A8
      --host 0.0.0.0
      --port $SERVER_PORT
      --data-parallel-size 4
      --data-parallel-size-local 2
      --data-parallel-address $LOCAL_IP
      --data-parallel-rpc-port 13399
      --tensor-parallel-size 8
      --decode-context-parallel-size 8
      --quantization ascend
      --seed 1024
      --enable-expert-parallel
      --max-num-seqs 16
      --max-model-len 8192
      --max-num-batched-tokens 4096
      --no-enable-prefix-caching
      --gpu-memory-utilization 0.85
      --trust-remote-code
      --speculative-config '{"num_speculative_tokens": 2, "method":"deepseek_mtp"}'
      --compilation-config '{"cudagraph_capture_sizes": [3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 33, 36, 39, 42, 45, 48], "cudagraph_mode": "FULL_DECODE_ONLY"}'
      --additional-config '{"layer_sharding": ["q_b_proj", "o_proj"]}'
      --tokenizer-mode deepseek_v32
      --reasoning-parser deepseek_v3
  -
    server_cmd: >
      vllm serve vllm-ascend/DeepSeek-V3.2-W8A8
      --headless
      --data-parallel-size 4
      --data-parallel-rpc-port 13399
      --data-parallel-size-local 2
      --data-parallel-start-rank 2
      --data-parallel-address $MASTER_IP
      --tensor-parallel-size 8
      --decode-context-parallel-size 8
      --quantization ascend
      --seed 1024
      --enable-expert-parallel
      --max-num-seqs 16
      --max-model-len 8192
      --max-num-batched-tokens 4096
      --no-enable-prefix-caching
      --gpu-memory-utilization 0.85
      --trust-remote-code
      --speculative-config '{"num_speculative_tokens": 2, "method":"deepseek_mtp"}'
      --compilation-config '{"cudagraph_capture_sizes": [3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 33, 36, 39, 42, 45, 48], "cudagraph_mode": "FULL_DECODE_ONLY"}'
      --additional-config '{"layer_sharding": ["q_b_proj", "o_proj"]}'
      --tokenizer-mode deepseek_v32
      --reasoning-parser deepseek_v3
benchmarks:
  perf:
    case_type: performance
    dataset_path: vllm-ascend/GSM8K-in3500-bs2800
    request_conf: vllm_api_stream_chat
    dataset_conf: gsm8k/gsm8k_gen_0_shot_cot_str_perf
    num_prompts: 512
    max_out_len: 3000
    batch_size: 512
    request_rate: 11.2
    baseline: 1253.8466
    threshold: 0.97
  acc:
    case_type: accuracy
    dataset_path: vllm-ascend/gsm8k-lite
    request_conf: vllm_api_general_chat
    dataset_conf: gsm8k/gsm8k_gen_0_shot_cot_chat_prompt
    max_out_len: 4096
    batch_size: 64
    baseline: 95
    threshold: 5


@@ -3,6 +3,7 @@
 from unittest.mock import MagicMock, patch

 import torch
+from tests.ut.attention.utils import patch_distributed_groups
 from tests.ut.base import TestBase
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm.distributed.parallel_state import GroupCoordinator
@@ -41,7 +42,7 @@ class TestAscendSFAMetadata(TestBase):
         slot_mapping = torch.randn(100, 4, 1024)
         seq_lens = torch.tensor([30, 50])
         cum_query_lens = torch.tensor([0, 30, 80])
-        block_tables = torch.randint(0, 100, (100, 4))
+        block_table = torch.randint(0, 100, (100, 4))
         rope_dim = 32
         max_seq_len = int(seq_lens.max().item())
@@ -58,7 +59,7 @@ class TestAscendSFAMetadata(TestBase):
             slot_mapping=slot_mapping,
             seq_lens=seq_lens,
             cum_query_lens=cum_query_lens,
-            block_tables=block_tables,
+            block_table=block_table,
             sin=sin,
             cos=cos,
             num_input_tokens=num_input_tokens,
@@ -71,7 +72,7 @@ class TestAscendSFAMetadata(TestBase):
         self.assertIs(metadata.slot_mapping, slot_mapping)
         self.assertTrue(torch.equal(metadata.seq_lens, seq_lens))
         self.assertTrue(torch.equal(metadata.cum_query_lens, cum_query_lens))
-        self.assertIs(metadata.block_tables, block_tables)
+        self.assertIs(metadata.block_table, block_table)
         self.assertIs(metadata.sin, sin)
         self.assertIs(metadata.cos, cos)
         self.assertEqual(metadata.num_input_tokens, num_input_tokens)
@@ -135,6 +136,7 @@ class TestAscendSFAMetadataBuilder(TestBase):
         self.patcher.stop()
         self.parent_init_patcher.stop()

+    @patch_distributed_groups(dcp_size=2, pcp_size=2, needs_mocks=False)
     def test_ascend_sfa_metadata_builder_default(self):
         kv_cache_spec = MagicMock()
         layer_names = ["layer1", "layer2"]
@@ -160,6 +162,7 @@ class TestAscendSFAMetadataBuilder(TestBase):
     @patch("vllm_ascend.attention.sfa_v1.get_current_vllm_config")
     @patch("vllm_ascend.attention.sfa_v1.get_cos_and_sin_mla")
     @patch("vllm_ascend.attention.sfa_v1.enable_dsa_cp")
+    @patch_distributed_groups(dcp_size=2, pcp_size=2, needs_mocks=False)
     def test_ascend_sfa_metadata_builder_build(
         self,
         mock_enable_dsa_cp,
@@ -222,6 +225,7 @@ class TestAscendSFAMetadataBuilder(TestBase):
     @patch("vllm_ascend.attention.sfa_v1.get_current_vllm_config")
     @patch("vllm_ascend.attention.sfa_v1.get_cos_and_sin_mla")
+    @patch_distributed_groups(dcp_size=2, pcp_size=2, needs_mocks=False)
     def test_ascend_sfa_metadata_builder_build_for_graph_capture(
         self, mock_get_cos_and_sin_mla, mock_get_current_vllm_config):
         cfg = MagicMock()


@@ -6,7 +6,7 @@
 import torch_npu
 import vllm.envs as envs_vllm
 from torch import nn
 from vllm.config import VllmConfig, get_current_vllm_config
-from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
+from vllm.distributed import get_dcp_group, get_pcp_group, get_tensor_model_parallel_world_size, get_tp_group
 from vllm.forward_context import get_forward_context
 from vllm.logger import logger
 from vllm.model_executor.layers.attention.mla_attention import MLACommonMetadataBuilder
@@ -95,6 +95,12 @@ class DSACPContext:
     actual_seq_lengths_key: torch.Tensor


+@dataclass
+class SFACPMetadata:
+    block_table_cp: torch.Tensor
+    valid_block_ids: torch.Tensor
+
+
 @dataclass
 class AscendSFAMetadata:
     """Metadata for MLACommon.
@@ -114,7 +120,7 @@ class AscendSFAMetadata:
     slot_mapping: torch.Tensor
     seq_lens: torch.Tensor
     cum_query_lens: torch.Tensor
-    block_tables: torch.Tensor
+    block_table: torch.Tensor
     sin: torch.Tensor
     cos: torch.Tensor
@@ -127,6 +133,7 @@ class AscendSFAMetadata:
     attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
     dsa_cp_context: DSACPContext | None = None
     reshape_cache_event: torch.npu.Event = None
+    sfa_cp_metadata: SFACPMetadata | None = None


 M = TypeVar("M", bound=AscendSFAMetadata)
@@ -178,6 +185,14 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
         self.actual_seq_lengths_query = torch.zeros(max_num_reqs + 1, dtype=torch.int32, device=device)
         self.actual_seq_lengths_key = torch.empty_like(self.actual_seq_lengths_query)

+        self.pcp_size = get_pcp_group().world_size
+        self.pcp_rank = get_pcp_group().rank_in_group if self.pcp_size > 1 else 0
+        self.pcp_group = get_pcp_group().device_group if self.pcp_size > 1 else None
+        self.dcp_size = get_dcp_group().world_size
+        self.dcp_rank = get_dcp_group().rank_in_group if self.dcp_size > 1 else 0
+        self.dcp_group = get_dcp_group().device_group if self.dcp_size > 1 else None
+
     @staticmethod
     def determine_chunked_prefill_workspace_size(vllm_config: VllmConfig) -> int:
         return ascend_chunked_prefill_workspace_size(vllm_config)
@@ -294,6 +309,22 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
                 actual_seq_lengths_key=actual_seq_lengths_key,
             )

+        sfa_cp_metadata = None
+        if self.pcp_size * self.dcp_size > 1:
+            valid_block_ids, new_block_table = block_table.flatten().unique(return_inverse=True)
+            num_blocks = valid_block_ids.shape[0]
+            # Note(qcs): `block_table_cp` will have dirty values in the part beyond kv_lens.
+            # We assume that we can always get the correct kv_lens or kv index,
+            # so we omit the dirty-value processing here.
+            block_table_cp = (
+                new_block_table.unsqueeze(-1).to(block_table)
+                + (torch.arange(self.pcp_size * self.dcp_size) * num_blocks).view(1, 1, -1).to(block_table)
+            ).reshape(block_table.shape[0], -1)
+            sfa_cp_metadata = SFACPMetadata(
+                block_table_cp=block_table_cp,
+                valid_block_ids=valid_block_ids,
+            )
+
         return self.metadata_cls(  # type: ignore
             num_input_tokens=common_attn_metadata.num_input_tokens,
             num_actual_tokens=num_actual_tokens,
@@ -303,10 +334,11 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
             head_dim=self.model_config.get_head_size(),
             attn_mask=self.attn_mask_builder.get_attention_mask(self.model_config),
             attn_state=common_attn_metadata.attn_state,
-            block_tables=block_table,
+            block_table=block_table,
             sin=sin[:num_input_tokens],
             cos=cos[:num_input_tokens],
             dsa_cp_context=dsa_cp_context,
+            sfa_cp_metadata=sfa_cp_metadata,
         )

     def build_for_graph_capture(
@@ -417,6 +449,14 @@ class AscendSFAImpl(MLAAttentionImpl):
             )
             register_all_layers_to_shard_weight_series(self.layer_sharding_kwargs)

+        self.pcp_size = get_pcp_group().world_size
+        self.pcp_rank = get_pcp_group().rank_in_group if self.pcp_size > 1 else 0
+        self.pcp_group = get_pcp_group().device_group if self.pcp_size > 1 else None
+        self.dcp_size = get_dcp_group().world_size
+        self.dcp_rank = get_dcp_group().rank_in_group if self.dcp_size > 1 else 0
+        self.dcp_group = get_dcp_group().device_group if self.dcp_size > 1 else None
+
     def process_weights_after_loading(self, act_dtype: torch.dtype):
         # NOTE: We currently do not support quant kv_b_proj.
         assert isinstance(self.kv_b_proj.quant_method, UnquantizedLinearMethod)
@@ -849,18 +889,28 @@ class AscendSFAImpl(MLAAttentionImpl):
             need_gather_q_kv=need_gather_q_kv,
         )

+        block_table = attn_metadata.block_table
+        kv = kv_cache[0]
+        key_rope = kv_cache[1]
+        if self.pcp_size * self.dcp_size > 1:
+            assert attn_metadata.sfa_cp_metadata is not None
+            valid_block_ids = attn_metadata.sfa_cp_metadata.valid_block_ids
+            kv = self.gather_kv_cross_cp(kv, valid_block_ids)
+            key_rope = self.gather_kv_cross_cp(key_rope, valid_block_ids)
+            block_table = attn_metadata.sfa_cp_metadata.block_table_cp
+
         attn_output = torch.ops._C_ascend.npu_sparse_flash_attention(
             query=ql_nope,
-            key=kv_cache[0],
-            value=kv_cache[0],
+            key=kv,
+            value=kv,
             sparse_indices=topk_indices,
             scale_value=self.scale,
             sparse_block_size=1,
-            block_table=attn_metadata.block_tables,
+            block_table=block_table,
             actual_seq_lengths_query=actual_seq_lengths_query,
             actual_seq_lengths_kv=actual_seq_lengths_key,
             query_rope=q_pe,
-            key_rope=kv_cache[1],
+            key_rope=key_rope,
             layout_query="TND",
             layout_kv="PA_BSND",
             sparse_mode=3,
@@ -894,6 +944,15 @@ class AscendSFAImpl(MLAAttentionImpl):
         return output_padded

+    def gather_kv_cross_cp(self, kv_cache: torch.Tensor, valid_block_ids: torch.Tensor) -> torch.Tensor:
+        # Note(qcs): we need to set cp_kv_cache_interleave_size = block_size for SFA!
+        kv_cache = torch.index_select(kv_cache, 0, valid_block_ids)
+        if self.dcp_size > 1:
+            kv_cache = get_dcp_group().all_gather(kv_cache, 0)
+        if self.pcp_size > 1:
+            kv_cache = get_pcp_group().all_gather(kv_cache, 0)
+        return kv_cache
+
     def indexer_select_pre_process(
         self,
         x: torch.Tensor,
@@ -969,11 +1028,16 @@ class AscendSFAImpl(MLAAttentionImpl):
         weights, _ = self.weights_proj(x)
         weights = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(weights, need_gather_q_kv)

-        block_table = attn_metadata.block_tables
+        key = kv_cache[2]
+        block_table = attn_metadata.block_table
+        if self.pcp_size * self.dcp_size > 1:
+            assert attn_metadata.sfa_cp_metadata is not None
+            key = self.gather_kv_cross_cp(key, attn_metadata.sfa_cp_metadata.valid_block_ids)
+            block_table = attn_metadata.sfa_cp_metadata.block_table_cp
+
         topk_indices = torch.ops._C_ascend.npu_lightning_indexer(
             query=q,
-            key=kv_cache[2],
+            key=key,
             weights=weights,
             actual_seq_lengths_query=actual_seq_lengths_query,
             actual_seq_lengths_key=actual_seq_lengths_key,


@@ -378,10 +378,11 @@ class NPUPlatform(Platform):
             vllm_config.scheduler_config.enable_chunked_prefill = True
         vllm_config.scheduler_config.SLO_limits_for_dynamic_batch = ascend_config.SLO_limits_for_dynamic_batch

+        cp_size = parallel_config.decode_context_parallel_size * parallel_config.prefill_context_parallel_size
         if (
             vllm_config.kv_transfer_config is not None
             and cache_config.block_size != parallel_config.cp_kv_cache_interleave_size
-            and parallel_config.decode_context_parallel_size * parallel_config.prefill_context_parallel_size > 1
+            and cp_size > 1
         ):
             raise AssertionError(
                 f"cp_kv_cache_interleave_size({parallel_config.cp_kv_cache_interleave_size}) "
@@ -389,6 +390,20 @@ class NPUPlatform(Platform):
                 "needs to be equal if use pcp or dcp > 1 in P/D disaggregate and kv pool scenario."
             )

+        use_sparse = (
+            model_config is not None
+            and model_config.hf_text_config is not None
+            and hasattr(model_config.hf_text_config, "index_topk")
+        )
+        if use_sparse and cp_size > 1 and parallel_config.cp_kv_cache_interleave_size != cache_config.block_size:
+            logger.warning_once(
+                "The current SFA's PCP&DCP implementation requires "
+                f"cp_kv_cache_interleave_size({parallel_config.cp_kv_cache_interleave_size})"
+                f" == block_size({cache_config.block_size}). "
+                f"Overriding cp_kv_cache_interleave_size to {cache_config.block_size}."
+            )
+            vllm_config.parallel_config.cp_kv_cache_interleave_size = cache_config.block_size
+
         if is_vl_model(vllm_config):
             if bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", "0"))) or bool(
                 int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM1", "0"))