[Feat](sfa,dcp) support dcp for sfa (#6563)

### What this PR does / why we need it?
This PR adds DCP support to the SFA backend.

Please note that, due to operator constraints, the current implementation
has to all-gather the entire KV cache and modify the block table to
satisfy the operator's input requirements. This results in significantly
increased communication overhead and peak memory usage. It is therefore
only a temporary workaround and will be refactored once the operator
provides proper support.

Additionally, because of the above limitations,
`cp_kv_cache_interleave_size` is currently required to be equal to
`block_size`. This restriction will also be removed after the refactor.

#### Test
Accuracy test using DeepSeek-V3.2-Exp-W8A8 with dp2tp8dcp8.

| dataset | version | metric | mode | vllm-api-general-stream |
|----- | ----- | ----- | ----- | -----|
| gsm8kdataset | - | accuracy | gen | 96.35 |

- vLLM version: v0.15.0
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.15.0

---------

Signed-off-by: QiuChunshuo <qiuchunshuo@huawei.com>
This commit is contained in:
Qiu
2026-02-09 18:52:25 +08:00
committed by GitHub
parent 80e5812b39
commit cb7c419bc0
5 changed files with 190 additions and 13 deletions

View File

@@ -6,7 +6,7 @@ import torch_npu
import vllm.envs as envs_vllm
from torch import nn
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
from vllm.distributed import get_dcp_group, get_pcp_group, get_tensor_model_parallel_world_size, get_tp_group
from vllm.forward_context import get_forward_context
from vllm.logger import logger
from vllm.model_executor.layers.attention.mla_attention import MLACommonMetadataBuilder
@@ -95,6 +95,12 @@ class DSACPContext:
actual_seq_lengths_key: torch.Tensor
@dataclass
class SFACPMetadata:
    """Context-parallel (PCP/DCP) metadata for the SFA attention backend.

    Built only when pcp_size * dcp_size > 1. Carries the remapped block
    table and the set of referenced block ids needed to gather the full
    KV cache across context-parallel ranks before the attention op runs.
    """

    # Block table remapped into the all-gathered KV-cache layout: each of
    # the pcp_size * dcp_size rank copies of the gathered cache is addressed
    # by adding `rank * num_valid_blocks` offsets along the last dim.
    block_table_cp: torch.Tensor
    # Unique physical block ids referenced by the original block table
    # (output of block_table.flatten().unique()); used with index_select
    # to gather only the blocks this batch actually touches.
    valid_block_ids: torch.Tensor
@dataclass
class AscendSFAMetadata:
"""Metadata for MLACommon.
@@ -114,7 +120,7 @@ class AscendSFAMetadata:
slot_mapping: torch.Tensor
seq_lens: torch.Tensor
cum_query_lens: torch.Tensor
block_tables: torch.Tensor
block_table: torch.Tensor
sin: torch.Tensor
cos: torch.Tensor
@@ -127,6 +133,7 @@ class AscendSFAMetadata:
attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
dsa_cp_context: DSACPContext | None = None
reshape_cache_event: torch.npu.Event = None
sfa_cp_metadata: SFACPMetadata | None = None
M = TypeVar("M", bound=AscendSFAMetadata)
@@ -178,6 +185,14 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
self.actual_seq_lengths_query = torch.zeros(max_num_reqs + 1, dtype=torch.int32, device=device)
self.actual_seq_lengths_key = torch.empty_like(self.actual_seq_lengths_query)
self.pcp_size = get_pcp_group().world_size
self.pcp_rank = get_pcp_group().rank_in_group if self.pcp_size > 1 else 0
self.pcp_group = get_pcp_group().device_group if self.pcp_size > 1 else None
self.dcp_size = get_dcp_group().world_size
self.dcp_rank = get_dcp_group().rank_in_group if self.dcp_size > 1 else 0
self.dcp_group = get_dcp_group().device_group if self.dcp_size > 1 else None
@staticmethod
def determine_chunked_prefill_workspace_size(vllm_config: VllmConfig) -> int:
    """Return the chunked-prefill workspace size (in elements/bytes as
    defined by the helper) for this config.

    Delegates to the Ascend-specific sizing helper rather than the base
    MLACommonMetadataBuilder logic.
    """
    return ascend_chunked_prefill_workspace_size(vllm_config)
@@ -294,6 +309,22 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
actual_seq_lengths_key=actual_seq_lengths_key,
)
sfa_cp_metadata = None
if self.pcp_size * self.dcp_size > 1:
valid_block_ids, new_block_table = block_table.flatten().unique(return_inverse=True)
num_blocks = valid_block_ids.shape[0]
# Note(qcs): `block_table_cp` will have dirty values in the part beyond kv_lens.
# We assume that we can always get the correct kv_lens or kv index,
# so we omit the dirty value processing here.
block_table_cp = (
new_block_table.unsqueeze(-1).to(block_table)
+ (torch.arange(self.pcp_size * self.dcp_size) * num_blocks).view(1, 1, -1).to(block_table)
).reshape(block_table.shape[0], -1)
sfa_cp_metadata = SFACPMetadata(
block_table_cp=block_table_cp,
valid_block_ids=valid_block_ids,
)
return self.metadata_cls( # type: ignore
num_input_tokens=common_attn_metadata.num_input_tokens,
num_actual_tokens=num_actual_tokens,
@@ -303,10 +334,11 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
head_dim=self.model_config.get_head_size(),
attn_mask=self.attn_mask_builder.get_attention_mask(self.model_config),
attn_state=common_attn_metadata.attn_state,
block_tables=block_table,
block_table=block_table,
sin=sin[:num_input_tokens],
cos=cos[:num_input_tokens],
dsa_cp_context=dsa_cp_context,
sfa_cp_metadata=sfa_cp_metadata,
)
def build_for_graph_capture(
@@ -417,6 +449,14 @@ class AscendSFAImpl(MLAAttentionImpl):
)
register_all_layers_to_shard_weight_series(self.layer_sharding_kwargs)
self.pcp_size = get_pcp_group().world_size
self.pcp_rank = get_pcp_group().rank_in_group if self.pcp_size > 1 else 0
self.pcp_group = get_pcp_group().device_group if self.pcp_size > 1 else None
self.dcp_size = get_dcp_group().world_size
self.dcp_rank = get_dcp_group().rank_in_group if self.dcp_size > 1 else 0
self.dcp_group = get_dcp_group().device_group if self.dcp_size > 1 else None
def process_weights_after_loading(self, act_dtype: torch.dtype):
# NOTE: We currently do not support quant kv_b_proj.
assert isinstance(self.kv_b_proj.quant_method, UnquantizedLinearMethod)
@@ -849,18 +889,28 @@ class AscendSFAImpl(MLAAttentionImpl):
need_gather_q_kv=need_gather_q_kv,
)
block_table = attn_metadata.block_table
kv = kv_cache[0]
key_rope = kv_cache[1]
if self.pcp_size * self.dcp_size > 1:
assert attn_metadata.sfa_cp_metadata is not None
valid_block_ids = attn_metadata.sfa_cp_metadata.valid_block_ids
kv = self.gather_kv_cross_cp(kv, valid_block_ids)
key_rope = self.gather_kv_cross_cp(key_rope, valid_block_ids)
block_table = attn_metadata.sfa_cp_metadata.block_table_cp
attn_output = torch.ops._C_ascend.npu_sparse_flash_attention(
query=ql_nope,
key=kv_cache[0],
value=kv_cache[0],
key=kv,
value=kv,
sparse_indices=topk_indices,
scale_value=self.scale,
sparse_block_size=1,
block_table=attn_metadata.block_tables,
block_table=block_table,
actual_seq_lengths_query=actual_seq_lengths_query,
actual_seq_lengths_kv=actual_seq_lengths_key,
query_rope=q_pe,
key_rope=kv_cache[1],
key_rope=key_rope,
layout_query="TND",
layout_kv="PA_BSND",
sparse_mode=3,
@@ -894,6 +944,15 @@ class AscendSFAImpl(MLAAttentionImpl):
return output_padded
def gather_kv_cross_cp(self, kv_cache: torch.Tensor, valid_block_ids: torch.Tensor) -> torch.Tensor:
    """Gather the referenced KV-cache blocks across context-parallel ranks.

    Selects only the paged-cache blocks listed in ``valid_block_ids``
    (indexing dim 0 of ``kv_cache``) and all-gathers the result along
    dim 0 — first over the DCP group, then over the PCP group — so the
    gathered tensor stacks one per-rank copy of the selected blocks.
    The companion ``block_table_cp`` (see SFACPMetadata) addresses this
    stacked layout via per-rank block-id offsets.

    NOTE(review): this all-gathers the whole selected KV range, which is a
    temporary workaround for operator constraints (see PR description) and
    significantly raises communication and peak-memory cost.
    """
    # Note(qcs): we need set kv_cache_interleave_size = block_size for sfa!!!
    kv_cache = torch.index_select(kv_cache, 0, valid_block_ids)
    if self.dcp_size > 1:
        kv_cache = get_dcp_group().all_gather(kv_cache, 0)
    if self.pcp_size > 1:
        kv_cache = get_pcp_group().all_gather(kv_cache, 0)
    return kv_cache
def indexer_select_pre_process(
self,
x: torch.Tensor,
@@ -969,11 +1028,16 @@ class AscendSFAImpl(MLAAttentionImpl):
weights, _ = self.weights_proj(x)
weights = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(weights, need_gather_q_kv)
block_table = attn_metadata.block_tables
key = kv_cache[2]
block_table = attn_metadata.block_table
if self.pcp_size * self.dcp_size > 1:
assert attn_metadata.sfa_cp_metadata is not None
key = self.gather_kv_cross_cp(key, attn_metadata.sfa_cp_metadata.valid_block_ids)
block_table = attn_metadata.sfa_cp_metadata.block_table_cp
topk_indices = torch.ops._C_ascend.npu_lightning_indexer(
query=q,
key=kv_cache[2],
key=key,
weights=weights,
actual_seq_lengths_query=actual_seq_lengths_query,
actual_seq_lengths_key=actual_seq_lengths_key,

View File

@@ -378,10 +378,11 @@ class NPUPlatform(Platform):
vllm_config.scheduler_config.enable_chunked_prefill = True
vllm_config.scheduler_config.SLO_limits_for_dynamic_batch = ascend_config.SLO_limits_for_dynamic_batch
cp_size = parallel_config.decode_context_parallel_size * parallel_config.prefill_context_parallel_size
if (
vllm_config.kv_transfer_config is not None
and cache_config.block_size != parallel_config.cp_kv_cache_interleave_size
and parallel_config.decode_context_parallel_size * parallel_config.prefill_context_parallel_size > 1
and cp_size > 1
):
raise AssertionError(
f"cp_kv_cache_interleave_size({parallel_config.cp_kv_cache_interleave_size}) "
@@ -389,6 +390,20 @@ class NPUPlatform(Platform):
"needs to be equal if use pcp or dcp > 1 in P/D disaggregate and kv pool scenario."
)
use_sparse = (
model_config is not None
and model_config.hf_text_config is not None
and hasattr(model_config.hf_text_config, "index_topk")
)
if use_sparse and cp_size > 1 and parallel_config.cp_kv_cache_interleave_size != cache_config.block_size:
logger.warning_once(
"The current SFA's PCP&DCP implementation requires"
f"cp_kv_cache_interleave_size({parallel_config.cp_kv_cache_interleave_size})"
f" == block_size({cache_config.block_size}). "
f"Override cp_kv_cache_interleave_size to {cache_config.block_size}."
)
vllm_config.parallel_config.cp_kv_cache_interleave_size = cache_config.block_size
if is_vl_model(vllm_config):
if bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", "0"))) or bool(
int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM1", "0"))