[Feat](sfa,dcp) support dcp for sfa (#6563)
### What this PR does / why we need it?

This PR adds DCP support to the SFA backend.

Please note that, due to operator constraints, the current implementation has to all-gather the entire KV cache and modify the block table to satisfy the operator's input requirements. This significantly increases communication overhead and peak memory usage. Therefore, this is only a temporary workaround and will be refactored once the operator provides proper support.

Additionally, because of the above limitations, `cp_kv_cache_interleave_size` is currently required to be equal to `block_size`. This restriction will also be removed after the refactor.

#### Test

Accuracy test using DeepSeek-V3.2-Exp-W8A8 with dp2tp8dcp8:

| dataset | version | metric | mode | vllm-api-general-stream |
|----- | ----- | ----- | ----- | -----|
| gsm8kdataset | - | accuracy | gen | 96.35 |

- vLLM version: v0.15.0
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.15.0

---------

Signed-off-by: QiuChunshuo <qiuchunshuo@huawei.com>
This commit is contained in:
@@ -6,7 +6,7 @@ import torch_npu
|
||||
import vllm.envs as envs_vllm
|
||||
from torch import nn
|
||||
from vllm.config import VllmConfig, get_current_vllm_config
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
|
||||
from vllm.distributed import get_dcp_group, get_pcp_group, get_tensor_model_parallel_world_size, get_tp_group
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.logger import logger
|
||||
from vllm.model_executor.layers.attention.mla_attention import MLACommonMetadataBuilder
|
||||
@@ -95,6 +95,12 @@ class DSACPContext:
|
||||
actual_seq_lengths_key: torch.Tensor
|
||||
|
||||
|
||||
@dataclass
|
||||
class SFACPMetadata:
# Extra metadata carried on AscendSFAMetadata.sfa_cp_metadata; the builder only
# fills it in when pcp_size * dcp_size > 1, otherwise that field stays None.
|
||||
# Block table rewritten to address the KV cache after it has been gathered
# across the context-parallel ranks: each original entry is replaced by its
# index within `valid_block_ids`, offset by `rank * num_unique_blocks` for
# every rank, and flattened back to one row per request.
block_table_cp: torch.Tensor
|
||||
# Unique block ids found in the original block table (the first output of
# `block_table.flatten().unique(return_inverse=True)`); used to select the
# local KV-cache entries before the all-gather.
valid_block_ids: torch.Tensor
|
||||
|
||||
|
||||
@dataclass
|
||||
class AscendSFAMetadata:
|
||||
"""Metadata for MLACommon.
|
||||
@@ -114,7 +120,7 @@ class AscendSFAMetadata:
|
||||
slot_mapping: torch.Tensor
|
||||
seq_lens: torch.Tensor
|
||||
cum_query_lens: torch.Tensor
|
||||
block_tables: torch.Tensor
|
||||
block_table: torch.Tensor
|
||||
sin: torch.Tensor
|
||||
cos: torch.Tensor
|
||||
|
||||
@@ -127,6 +133,7 @@ class AscendSFAMetadata:
|
||||
attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
|
||||
dsa_cp_context: DSACPContext | None = None
|
||||
reshape_cache_event: torch.npu.Event = None
|
||||
sfa_cp_metadata: SFACPMetadata | None = None
|
||||
|
||||
|
||||
M = TypeVar("M", bound=AscendSFAMetadata)
|
||||
@@ -178,6 +185,14 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
|
||||
self.actual_seq_lengths_query = torch.zeros(max_num_reqs + 1, dtype=torch.int32, device=device)
|
||||
self.actual_seq_lengths_key = torch.empty_like(self.actual_seq_lengths_query)
|
||||
|
||||
self.pcp_size = get_pcp_group().world_size
|
||||
self.pcp_rank = get_pcp_group().rank_in_group if self.pcp_size > 1 else 0
|
||||
self.pcp_group = get_pcp_group().device_group if self.pcp_size > 1 else None
|
||||
|
||||
self.dcp_size = get_dcp_group().world_size
|
||||
self.dcp_rank = get_dcp_group().rank_in_group if self.dcp_size > 1 else 0
|
||||
self.dcp_group = get_dcp_group().device_group if self.dcp_size > 1 else None
|
||||
|
||||
@staticmethod
|
||||
def determine_chunked_prefill_workspace_size(vllm_config: VllmConfig) -> int:
|
||||
return ascend_chunked_prefill_workspace_size(vllm_config)
|
||||
@@ -294,6 +309,22 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
|
||||
actual_seq_lengths_key=actual_seq_lengths_key,
|
||||
)
|
||||
|
||||
sfa_cp_metadata = None
|
||||
if self.pcp_size * self.dcp_size > 1:
|
||||
valid_block_ids, new_block_table = block_table.flatten().unique(return_inverse=True)
|
||||
num_blocks = valid_block_ids.shape[0]
|
||||
# Note(qcs): `block_table_cp` will have dirty values in the part beyond kv_lens.
|
||||
# We assume that we can always get the correct kv_lens or kv index,
|
||||
# so we omit the dirty value processing here.
|
||||
block_table_cp = (
|
||||
new_block_table.unsqueeze(-1).to(block_table)
|
||||
+ (torch.arange(self.pcp_size * self.dcp_size) * num_blocks).view(1, 1, -1).to(block_table)
|
||||
).reshape(block_table.shape[0], -1)
|
||||
sfa_cp_metadata = SFACPMetadata(
|
||||
block_table_cp=block_table_cp,
|
||||
valid_block_ids=valid_block_ids,
|
||||
)
|
||||
|
||||
return self.metadata_cls( # type: ignore
|
||||
num_input_tokens=common_attn_metadata.num_input_tokens,
|
||||
num_actual_tokens=num_actual_tokens,
|
||||
@@ -303,10 +334,11 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
|
||||
head_dim=self.model_config.get_head_size(),
|
||||
attn_mask=self.attn_mask_builder.get_attention_mask(self.model_config),
|
||||
attn_state=common_attn_metadata.attn_state,
|
||||
block_tables=block_table,
|
||||
block_table=block_table,
|
||||
sin=sin[:num_input_tokens],
|
||||
cos=cos[:num_input_tokens],
|
||||
dsa_cp_context=dsa_cp_context,
|
||||
sfa_cp_metadata=sfa_cp_metadata,
|
||||
)
|
||||
|
||||
def build_for_graph_capture(
|
||||
@@ -417,6 +449,14 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
)
|
||||
register_all_layers_to_shard_weight_series(self.layer_sharding_kwargs)
|
||||
|
||||
self.pcp_size = get_pcp_group().world_size
|
||||
self.pcp_rank = get_pcp_group().rank_in_group if self.pcp_size > 1 else 0
|
||||
self.pcp_group = get_pcp_group().device_group if self.pcp_size > 1 else None
|
||||
|
||||
self.dcp_size = get_dcp_group().world_size
|
||||
self.dcp_rank = get_dcp_group().rank_in_group if self.dcp_size > 1 else 0
|
||||
self.dcp_group = get_dcp_group().device_group if self.dcp_size > 1 else None
|
||||
|
||||
def process_weights_after_loading(self, act_dtype: torch.dtype):
|
||||
# NOTE: We currently do not support quant kv_b_proj.
|
||||
assert isinstance(self.kv_b_proj.quant_method, UnquantizedLinearMethod)
|
||||
@@ -849,18 +889,28 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
need_gather_q_kv=need_gather_q_kv,
|
||||
)
|
||||
|
||||
block_table = attn_metadata.block_table
|
||||
kv = kv_cache[0]
|
||||
key_rope = kv_cache[1]
|
||||
if self.pcp_size * self.dcp_size > 1:
|
||||
assert attn_metadata.sfa_cp_metadata is not None
|
||||
valid_block_ids = attn_metadata.sfa_cp_metadata.valid_block_ids
|
||||
kv = self.gather_kv_cross_cp(kv, valid_block_ids)
|
||||
key_rope = self.gather_kv_cross_cp(key_rope, valid_block_ids)
|
||||
block_table = attn_metadata.sfa_cp_metadata.block_table_cp
|
||||
|
||||
attn_output = torch.ops._C_ascend.npu_sparse_flash_attention(
|
||||
query=ql_nope,
|
||||
key=kv_cache[0],
|
||||
value=kv_cache[0],
|
||||
key=kv,
|
||||
value=kv,
|
||||
sparse_indices=topk_indices,
|
||||
scale_value=self.scale,
|
||||
sparse_block_size=1,
|
||||
block_table=attn_metadata.block_tables,
|
||||
block_table=block_table,
|
||||
actual_seq_lengths_query=actual_seq_lengths_query,
|
||||
actual_seq_lengths_kv=actual_seq_lengths_key,
|
||||
query_rope=q_pe,
|
||||
key_rope=kv_cache[1],
|
||||
key_rope=key_rope,
|
||||
layout_query="TND",
|
||||
layout_kv="PA_BSND",
|
||||
sparse_mode=3,
|
||||
@@ -894,6 +944,15 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
|
||||
return output_padded
|
||||
|
||||
def gather_kv_cross_cp(self, kv_cache: torch.Tensor, valid_block_ids: torch.Tensor) -> torch.Tensor:
# Select the KV-cache entries listed in `valid_block_ids` and gather them
# from every context-parallel rank.
#
# Args:
#   kv_cache: local cache tensor; dim 0 is the dimension indexed by block id.
#   valid_block_ids: 1-D index tensor of the entries to keep along dim 0.
# Returns:
#   The selected entries, concatenated along dim 0 across the DCP group
#   (when dcp_size > 1) and then across the PCP group (when pcp_size > 1).
#   With both sizes equal to 1 only the local selection is returned.
|
||||
# Note(qcs): kv_cache_interleave_size must be set equal to block_size for SFA.
|
||||
# Drop every entry the current block table does not reference before
# communicating, so only the needed part of the cache is exchanged.
kv_cache = torch.index_select(kv_cache, 0, valid_block_ids)
|
||||
if self.dcp_size > 1:
|
||||
# Gather along dim 0 across the DCP group first ...
kv_cache = get_dcp_group().all_gather(kv_cache, 0)
|
||||
if self.pcp_size > 1:
|
||||
# ... then across the PCP group, again along dim 0.
kv_cache = get_pcp_group().all_gather(kv_cache, 0)
|
||||
return kv_cache
|
||||
|
||||
def indexer_select_pre_process(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
@@ -969,11 +1028,16 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
weights, _ = self.weights_proj(x)
|
||||
weights = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(weights, need_gather_q_kv)
|
||||
|
||||
block_table = attn_metadata.block_tables
|
||||
key = kv_cache[2]
|
||||
block_table = attn_metadata.block_table
|
||||
if self.pcp_size * self.dcp_size > 1:
|
||||
assert attn_metadata.sfa_cp_metadata is not None
|
||||
key = self.gather_kv_cross_cp(key, attn_metadata.sfa_cp_metadata.valid_block_ids)
|
||||
block_table = attn_metadata.sfa_cp_metadata.block_table_cp
|
||||
|
||||
topk_indices = torch.ops._C_ascend.npu_lightning_indexer(
|
||||
query=q,
|
||||
key=kv_cache[2],
|
||||
key=key,
|
||||
weights=weights,
|
||||
actual_seq_lengths_query=actual_seq_lengths_query,
|
||||
actual_seq_lengths_key=actual_seq_lengths_key,
|
||||
|
||||
Reference in New Issue
Block a user