[Feat](sfa,dcp) support dcp for sfa (#6563)
### What this PR does / why we need it? This PR adds DCP support to the SFA backend. Please note that due to operator constraints, the current implementation has to all-gather the entire KV cache and modify the block table to satisfy the operator input requirements. This results in significantly increased communication overhead and peak memory usage. Therefore, this is only a temporary workaround and will be refactored once the operator provides proper support. Additionally, because of the above limitations, `cp_kv_cache_interleave_size` is currently required to be equal to `block_size`. This restriction will also be removed after the refactor. #### Test accuracy test using DeepSeek-V3.2-Exp-W8A8 with dp2tp8dcp8 | dataset | version | metric | mode | vllm-api-general-stream | |----- | ----- | ----- | ----- | -----| | gsm8kdataset | - | accuracy | gen | 96.35 | - vLLM version: v0.15.0 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.15.0 --------- Signed-off-by: QiuChunshuo <qiuchunshuo@huawei.com>
This commit is contained in:
@@ -77,6 +77,9 @@ jobs:
|
||||
- name: multi-node-qwenw8a8-2node-longseq
|
||||
config_file_path: Qwen3-235B-W8A8-longseq.yaml
|
||||
size: 2
|
||||
- name: multi-node-deepseek-V3_2-W8A8-cp
|
||||
config_file_path: DeepSeek-V3_2-W8A8-cp.yaml
|
||||
size: 2
|
||||
- name: multi-node-qwen-disagg-pd
|
||||
config_file_path: Qwen3-235B-disagg-pd.yaml
|
||||
size: 2
|
||||
|
||||
@@ -0,0 +1,91 @@
|
||||
test_name: "test DeepSeek-V3.2-W8A8 for PCP&DCP"
|
||||
model: "vllm-ascend/DeepSeek-V3.2-W8A8"
|
||||
num_nodes: 2
|
||||
npu_per_node: 16
|
||||
env_common:
|
||||
HCCL_OP_EXPANSION_MODE: "AIV"
|
||||
|
||||
VLLM_USE_MODELSCOPE: true
|
||||
HCCL_BUFFSIZE: 1024
|
||||
SERVER_PORT: 8080
|
||||
OMP_PROC_BIND: false
|
||||
OMP_NUM_THREADS: 1
|
||||
PYTORCH_NPU_ALLOC_CONF: "expandable_segments:True"
|
||||
VLLM_ASCEND_ENABLE_FLASHCOMM1: 1
|
||||
ASCEND_A3_EBA_ENABLE: 1
|
||||
|
||||
|
||||
deployment:
|
||||
-
|
||||
server_cmd: >
|
||||
vllm serve vllm-ascend/DeepSeek-V3.2-W8A8
|
||||
--host 0.0.0.0
|
||||
--port $SERVER_PORT
|
||||
--data-parallel-size 4
|
||||
--data-parallel-size-local 2
|
||||
--data-parallel-address $LOCAL_IP
|
||||
--data-parallel-rpc-port 13399
|
||||
--tensor-parallel-size 8
|
||||
--decode-context-parallel-size 8
|
||||
--quantization ascend
|
||||
--seed 1024
|
||||
--enable-expert-parallel
|
||||
--max-num-seqs 16
|
||||
--max-model-len 8192
|
||||
--max-num-batched-tokens 4096
|
||||
--no-enable-prefix-caching
|
||||
--gpu-memory-utilization 0.85
|
||||
--trust-remote-code
|
||||
--speculative-config '{"num_speculative_tokens": 2, "method":"deepseek_mtp"}'
|
||||
--compilation-config '{"cudagraph_capture_sizes": [3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 33, 36, 39, 42, 45, 48], "cudagraph_mode": "FULL_DECODE_ONLY"}'
|
||||
--additional-config '{"layer_sharding": ["q_b_proj", "o_proj"]}'
|
||||
--tokenizer-mode deepseek_v32
|
||||
--reasoning-parser deepseek_v3
|
||||
|
||||
-
|
||||
server_cmd: >
|
||||
vllm serve vllm-ascend/DeepSeek-V3.2-W8A8
|
||||
--headless
|
||||
--data-parallel-size 4
|
||||
--data-parallel-rpc-port 13399
|
||||
--data-parallel-size-local 2
|
||||
--data-parallel-start-rank 2
|
||||
--data-parallel-address $MASTER_IP
|
||||
--tensor-parallel-size 8
|
||||
--decode-context-parallel-size 8
|
||||
--quantization ascend
|
||||
--seed 1024
|
||||
--enable-expert-parallel
|
||||
--max-num-seqs 16
|
||||
--max-model-len 8192
|
||||
--max-num-batched-tokens 4096
|
||||
--no-enable-prefix-caching
|
||||
--gpu-memory-utilization 0.85
|
||||
--trust-remote-code
|
||||
--speculative-config '{"num_speculative_tokens": 2, "method":"deepseek_mtp"}'
|
||||
--compilation-config '{"cudagraph_capture_sizes": [3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 33, 36, 39, 42, 45, 48], "cudagraph_mode": "FULL_DECODE_ONLY"}'
|
||||
--additional-config '{"layer_sharding": ["q_b_proj", "o_proj"]}'
|
||||
--tokenizer-mode deepseek_v32
|
||||
--reasoning-parser deepseek_v3
|
||||
benchmarks:
|
||||
perf:
|
||||
case_type: performance
|
||||
dataset_path: vllm-ascend/GSM8K-in3500-bs2800
|
||||
request_conf: vllm_api_stream_chat
|
||||
dataset_conf: gsm8k/gsm8k_gen_0_shot_cot_str_perf
|
||||
num_prompts: 512
|
||||
max_out_len: 3000
|
||||
batch_size: 512
|
||||
request_rate: 11.2
|
||||
baseline: 1253.8466
|
||||
threshold: 0.97
|
||||
|
||||
acc:
|
||||
case_type: accuracy
|
||||
dataset_path: vllm-ascend/gsm8k-lite
|
||||
request_conf: vllm_api_general_chat
|
||||
dataset_conf: gsm8k/gsm8k_gen_0_shot_cot_chat_prompt
|
||||
max_out_len: 4096
|
||||
batch_size: 64
|
||||
baseline: 95
|
||||
threshold: 5
|
||||
@@ -3,6 +3,7 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
import torch
|
||||
|
||||
from tests.ut.attention.utils import patch_distributed_groups
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||
from vllm.distributed.parallel_state import GroupCoordinator
|
||||
@@ -41,7 +42,7 @@ class TestAscendSFAMetadata(TestBase):
|
||||
slot_mapping = torch.randn(100, 4, 1024)
|
||||
seq_lens = torch.tensor([30, 50])
|
||||
cum_query_lens = torch.tensor([0, 30, 80])
|
||||
block_tables = torch.randint(0, 100, (100, 4))
|
||||
block_table = torch.randint(0, 100, (100, 4))
|
||||
|
||||
rope_dim = 32
|
||||
max_seq_len = int(seq_lens.max().item())
|
||||
@@ -58,7 +59,7 @@ class TestAscendSFAMetadata(TestBase):
|
||||
slot_mapping=slot_mapping,
|
||||
seq_lens=seq_lens,
|
||||
cum_query_lens=cum_query_lens,
|
||||
block_tables=block_tables,
|
||||
block_table=block_table,
|
||||
sin=sin,
|
||||
cos=cos,
|
||||
num_input_tokens=num_input_tokens,
|
||||
@@ -71,7 +72,7 @@ class TestAscendSFAMetadata(TestBase):
|
||||
self.assertIs(metadata.slot_mapping, slot_mapping)
|
||||
self.assertTrue(torch.equal(metadata.seq_lens, seq_lens))
|
||||
self.assertTrue(torch.equal(metadata.cum_query_lens, cum_query_lens))
|
||||
self.assertIs(metadata.block_tables, block_tables)
|
||||
self.assertIs(metadata.block_table, block_table)
|
||||
self.assertIs(metadata.sin, sin)
|
||||
self.assertIs(metadata.cos, cos)
|
||||
self.assertEqual(metadata.num_input_tokens, num_input_tokens)
|
||||
@@ -135,6 +136,7 @@ class TestAscendSFAMetadataBuilder(TestBase):
|
||||
self.patcher.stop()
|
||||
self.parent_init_patcher.stop()
|
||||
|
||||
@patch_distributed_groups(dcp_size=2, pcp_size=2, needs_mocks=False)
|
||||
def test_ascend_sfa_metadata_builder_default(self):
|
||||
kv_cache_spec = MagicMock()
|
||||
layer_names = ["layer1", "layer2"]
|
||||
@@ -160,6 +162,7 @@ class TestAscendSFAMetadataBuilder(TestBase):
|
||||
@patch("vllm_ascend.attention.sfa_v1.get_current_vllm_config")
|
||||
@patch("vllm_ascend.attention.sfa_v1.get_cos_and_sin_mla")
|
||||
@patch("vllm_ascend.attention.sfa_v1.enable_dsa_cp")
|
||||
@patch_distributed_groups(dcp_size=2, pcp_size=2, needs_mocks=False)
|
||||
def test_ascend_sfa_metadata_builder_build(
|
||||
self,
|
||||
mock_enable_dsa_cp,
|
||||
@@ -222,6 +225,7 @@ class TestAscendSFAMetadataBuilder(TestBase):
|
||||
|
||||
@patch("vllm_ascend.attention.sfa_v1.get_current_vllm_config")
|
||||
@patch("vllm_ascend.attention.sfa_v1.get_cos_and_sin_mla")
|
||||
@patch_distributed_groups(dcp_size=2, pcp_size=2, needs_mocks=False)
|
||||
def test_ascend_sfa_metadata_builder_build_for_graph_capture(
|
||||
self, mock_get_cos_and_sin_mla, mock_get_current_vllm_config):
|
||||
cfg = MagicMock()
|
||||
|
||||
@@ -6,7 +6,7 @@ import torch_npu
|
||||
import vllm.envs as envs_vllm
|
||||
from torch import nn
|
||||
from vllm.config import VllmConfig, get_current_vllm_config
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
|
||||
from vllm.distributed import get_dcp_group, get_pcp_group, get_tensor_model_parallel_world_size, get_tp_group
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.logger import logger
|
||||
from vllm.model_executor.layers.attention.mla_attention import MLACommonMetadataBuilder
|
||||
@@ -95,6 +95,12 @@ class DSACPContext:
|
||||
actual_seq_lengths_key: torch.Tensor
|
||||
|
||||
|
||||
@dataclass
|
||||
class SFACPMetadata:
|
||||
block_table_cp: torch.Tensor
|
||||
valid_block_ids: torch.Tensor
|
||||
|
||||
|
||||
@dataclass
|
||||
class AscendSFAMetadata:
|
||||
"""Metadata for MLACommon.
|
||||
@@ -114,7 +120,7 @@ class AscendSFAMetadata:
|
||||
slot_mapping: torch.Tensor
|
||||
seq_lens: torch.Tensor
|
||||
cum_query_lens: torch.Tensor
|
||||
block_tables: torch.Tensor
|
||||
block_table: torch.Tensor
|
||||
sin: torch.Tensor
|
||||
cos: torch.Tensor
|
||||
|
||||
@@ -127,6 +133,7 @@ class AscendSFAMetadata:
|
||||
attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
|
||||
dsa_cp_context: DSACPContext | None = None
|
||||
reshape_cache_event: torch.npu.Event = None
|
||||
sfa_cp_metadata: SFACPMetadata | None = None
|
||||
|
||||
|
||||
M = TypeVar("M", bound=AscendSFAMetadata)
|
||||
@@ -178,6 +185,14 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
|
||||
self.actual_seq_lengths_query = torch.zeros(max_num_reqs + 1, dtype=torch.int32, device=device)
|
||||
self.actual_seq_lengths_key = torch.empty_like(self.actual_seq_lengths_query)
|
||||
|
||||
self.pcp_size = get_pcp_group().world_size
|
||||
self.pcp_rank = get_pcp_group().rank_in_group if self.pcp_size > 1 else 0
|
||||
self.pcp_group = get_pcp_group().device_group if self.pcp_size > 1 else None
|
||||
|
||||
self.dcp_size = get_dcp_group().world_size
|
||||
self.dcp_rank = get_dcp_group().rank_in_group if self.dcp_size > 1 else 0
|
||||
self.dcp_group = get_dcp_group().device_group if self.dcp_size > 1 else None
|
||||
|
||||
@staticmethod
|
||||
def determine_chunked_prefill_workspace_size(vllm_config: VllmConfig) -> int:
|
||||
return ascend_chunked_prefill_workspace_size(vllm_config)
|
||||
@@ -294,6 +309,22 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
|
||||
actual_seq_lengths_key=actual_seq_lengths_key,
|
||||
)
|
||||
|
||||
sfa_cp_metadata = None
|
||||
if self.pcp_size * self.dcp_size > 1:
|
||||
valid_block_ids, new_block_table = block_table.flatten().unique(return_inverse=True)
|
||||
num_blocks = valid_block_ids.shape[0]
|
||||
# Note(qcs): `block_table_cp` will have dirty values in the part beyond kv_lens.
|
||||
# We assume that we can always get the correct kv_lens or kv index,
|
||||
# so we omit the dirty value processing here.
|
||||
block_table_cp = (
|
||||
new_block_table.unsqueeze(-1).to(block_table)
|
||||
+ (torch.arange(self.pcp_size * self.dcp_size) * num_blocks).view(1, 1, -1).to(block_table)
|
||||
).reshape(block_table.shape[0], -1)
|
||||
sfa_cp_metadata = SFACPMetadata(
|
||||
block_table_cp=block_table_cp,
|
||||
valid_block_ids=valid_block_ids,
|
||||
)
|
||||
|
||||
return self.metadata_cls( # type: ignore
|
||||
num_input_tokens=common_attn_metadata.num_input_tokens,
|
||||
num_actual_tokens=num_actual_tokens,
|
||||
@@ -303,10 +334,11 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
|
||||
head_dim=self.model_config.get_head_size(),
|
||||
attn_mask=self.attn_mask_builder.get_attention_mask(self.model_config),
|
||||
attn_state=common_attn_metadata.attn_state,
|
||||
block_tables=block_table,
|
||||
block_table=block_table,
|
||||
sin=sin[:num_input_tokens],
|
||||
cos=cos[:num_input_tokens],
|
||||
dsa_cp_context=dsa_cp_context,
|
||||
sfa_cp_metadata=sfa_cp_metadata,
|
||||
)
|
||||
|
||||
def build_for_graph_capture(
|
||||
@@ -417,6 +449,14 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
)
|
||||
register_all_layers_to_shard_weight_series(self.layer_sharding_kwargs)
|
||||
|
||||
self.pcp_size = get_pcp_group().world_size
|
||||
self.pcp_rank = get_pcp_group().rank_in_group if self.pcp_size > 1 else 0
|
||||
self.pcp_group = get_pcp_group().device_group if self.pcp_size > 1 else None
|
||||
|
||||
self.dcp_size = get_dcp_group().world_size
|
||||
self.dcp_rank = get_dcp_group().rank_in_group if self.dcp_size > 1 else 0
|
||||
self.dcp_group = get_dcp_group().device_group if self.dcp_size > 1 else None
|
||||
|
||||
def process_weights_after_loading(self, act_dtype: torch.dtype):
|
||||
# NOTE: We currently do not support quant kv_b_proj.
|
||||
assert isinstance(self.kv_b_proj.quant_method, UnquantizedLinearMethod)
|
||||
@@ -849,18 +889,28 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
need_gather_q_kv=need_gather_q_kv,
|
||||
)
|
||||
|
||||
block_table = attn_metadata.block_table
|
||||
kv = kv_cache[0]
|
||||
key_rope = kv_cache[1]
|
||||
if self.pcp_size * self.dcp_size > 1:
|
||||
assert attn_metadata.sfa_cp_metadata is not None
|
||||
valid_block_ids = attn_metadata.sfa_cp_metadata.valid_block_ids
|
||||
kv = self.gather_kv_cross_cp(kv, valid_block_ids)
|
||||
key_rope = self.gather_kv_cross_cp(key_rope, valid_block_ids)
|
||||
block_table = attn_metadata.sfa_cp_metadata.block_table_cp
|
||||
|
||||
attn_output = torch.ops._C_ascend.npu_sparse_flash_attention(
|
||||
query=ql_nope,
|
||||
key=kv_cache[0],
|
||||
value=kv_cache[0],
|
||||
key=kv,
|
||||
value=kv,
|
||||
sparse_indices=topk_indices,
|
||||
scale_value=self.scale,
|
||||
sparse_block_size=1,
|
||||
block_table=attn_metadata.block_tables,
|
||||
block_table=block_table,
|
||||
actual_seq_lengths_query=actual_seq_lengths_query,
|
||||
actual_seq_lengths_kv=actual_seq_lengths_key,
|
||||
query_rope=q_pe,
|
||||
key_rope=kv_cache[1],
|
||||
key_rope=key_rope,
|
||||
layout_query="TND",
|
||||
layout_kv="PA_BSND",
|
||||
sparse_mode=3,
|
||||
@@ -894,6 +944,15 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
|
||||
return output_padded
|
||||
|
||||
def gather_kv_cross_cp(self, kv_cache: torch.Tensor, valid_block_ids: torch.Tensor) -> torch.Tensor:
|
||||
# Note(qcs): we need set kv_cache_interleave_size = block_size for sfa!!!
|
||||
kv_cache = torch.index_select(kv_cache, 0, valid_block_ids)
|
||||
if self.dcp_size > 1:
|
||||
kv_cache = get_dcp_group().all_gather(kv_cache, 0)
|
||||
if self.pcp_size > 1:
|
||||
kv_cache = get_pcp_group().all_gather(kv_cache, 0)
|
||||
return kv_cache
|
||||
|
||||
def indexer_select_pre_process(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
@@ -969,11 +1028,16 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
weights, _ = self.weights_proj(x)
|
||||
weights = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(weights, need_gather_q_kv)
|
||||
|
||||
block_table = attn_metadata.block_tables
|
||||
key = kv_cache[2]
|
||||
block_table = attn_metadata.block_table
|
||||
if self.pcp_size * self.dcp_size > 1:
|
||||
assert attn_metadata.sfa_cp_metadata is not None
|
||||
key = self.gather_kv_cross_cp(key, attn_metadata.sfa_cp_metadata.valid_block_ids)
|
||||
block_table = attn_metadata.sfa_cp_metadata.block_table_cp
|
||||
|
||||
topk_indices = torch.ops._C_ascend.npu_lightning_indexer(
|
||||
query=q,
|
||||
key=kv_cache[2],
|
||||
key=key,
|
||||
weights=weights,
|
||||
actual_seq_lengths_query=actual_seq_lengths_query,
|
||||
actual_seq_lengths_key=actual_seq_lengths_key,
|
||||
|
||||
@@ -378,10 +378,11 @@ class NPUPlatform(Platform):
|
||||
vllm_config.scheduler_config.enable_chunked_prefill = True
|
||||
vllm_config.scheduler_config.SLO_limits_for_dynamic_batch = ascend_config.SLO_limits_for_dynamic_batch
|
||||
|
||||
cp_size = parallel_config.decode_context_parallel_size * parallel_config.prefill_context_parallel_size
|
||||
if (
|
||||
vllm_config.kv_transfer_config is not None
|
||||
and cache_config.block_size != parallel_config.cp_kv_cache_interleave_size
|
||||
and parallel_config.decode_context_parallel_size * parallel_config.prefill_context_parallel_size > 1
|
||||
and cp_size > 1
|
||||
):
|
||||
raise AssertionError(
|
||||
f"cp_kv_cache_interleave_size({parallel_config.cp_kv_cache_interleave_size}) "
|
||||
@@ -389,6 +390,20 @@ class NPUPlatform(Platform):
|
||||
"needs to be equal if use pcp or dcp > 1 in P/D disaggregate and kv pool scenario."
|
||||
)
|
||||
|
||||
use_sparse = (
|
||||
model_config is not None
|
||||
and model_config.hf_text_config is not None
|
||||
and hasattr(model_config.hf_text_config, "index_topk")
|
||||
)
|
||||
if use_sparse and cp_size > 1 and parallel_config.cp_kv_cache_interleave_size != cache_config.block_size:
|
||||
logger.warning_once(
|
||||
"The current SFA's PCP&DCP implementation requires"
|
||||
f"cp_kv_cache_interleave_size({parallel_config.cp_kv_cache_interleave_size})"
|
||||
f" == block_size({cache_config.block_size}). "
|
||||
f"Override cp_kv_cache_interleave_size to {cache_config.block_size}."
|
||||
)
|
||||
vllm_config.parallel_config.cp_kv_cache_interleave_size = cache_config.block_size
|
||||
|
||||
if is_vl_model(vllm_config):
|
||||
if bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", "0"))) or bool(
|
||||
int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM1", "0"))
|
||||
|
||||
Reference in New Issue
Block a user