[Feat]Adapt the graph mode (piecewise and full_decode_only) of PCP and DCP for DeepSeek v3.2. (#6940)

### What this PR does / why we need it?
Adapt the graph mode (piecewise and full_decode_only) of PCP and DCP for
DeepSeek v3.2.

### How was this patch tested?
Test output:

{"object":"text_completion","model":"deepeek_v3","choices":[{"index":0,"text":"
the head of state and head of government of the United States,
indirectly elected to a four-year term by the American people through
the Electoral College. The officeholder leads the executive branch of
the federal government and is the commander-in-chief of the United
States","logprobs":null,"finish_reason":"length","stop_reason":null,"token_ids":null,"prompt_logprobs":null,"prompt_token_ids":null},{"index":1,"text":"
Paris. This is the largest city in France and its main political,
cultural and commercial center. The modern location of the city is the
north of the central part of the country, on the banks of the Seine
River Seine River Seine in
3\n\n","logprobs":null,"finish_reason":"length","stop_reason":null,"token_ids":null,"prompt_logprobs":null,"prompt_token_ids":null},{"index":2,"text":"
now\n\n# AI future is now\n\nThe world is changing at a rapid pace, and
artificial intelligence (AI) is at the forefront of this transformation.
From self-driving cars to virtual assistants, AI is already making a
significant impact on our daily
lives","logprobs":null,"finish_reason":"length","stop_reason":null,"token_ids":null,"prompt_logprobs":null,"prompt_token_ids":null},{"index":3,"text":"
a 3rd year student at the University of Lincoln studying Media
Production. This blog is about my work throughout my final year on the
course.\n\n## Tuesday 3 May 2016\n### Final Major Project -
Evaluation\n\nFor my final project
I","logprobs":null,"finish_reason":"length","stop_reason":null,"token_ids":null,"prompt_logprobs":null,"prompt_token_ids":null}],"service_tier":null,"system_fingerprint":null,"usage":{"prompt_tokens":27,"total_tokens":227,"completion_tokens":200,"prompt_tokens_details":null},"kv_transfer_params":null}

- vLLM version: v0.16.0
- vLLM main:
15d76f74e2

---------

Signed-off-by: xiaocongtou6 <2066962956@qq.com>
Signed-off-by: xiaocongtou6 <105542647+xiaocongtou6@users.noreply.github.com>
This commit is contained in:
xiaocongtou6
2026-03-06 16:10:24 +08:00
committed by GitHub
parent a813eadd2d
commit bc0fd7ca72
3 changed files with 60 additions and 33 deletions

View File

@@ -25,6 +25,7 @@ from tests.e2e.conftest import VllmRunner, wait_until_npu_memory_free
os.environ["HCCL_BUFFSIZE"] = "768"
@wait_until_npu_memory_free()
def test_models_pcp_dcp_basic():
prompts = [
"The capital of France is", "Hello, my name is Tom, I am",
@@ -59,12 +60,12 @@ def test_models_pcp_dcp_basic():
model = "vllm-ascend/DeepSeek-V3.2-W8A8-Pruning"
with VllmRunner(
model,
enforce_eager=True,
max_model_len=1024,
tensor_parallel_size=2,
prefill_context_parallel_size=2,
decode_context_parallel_size=2,
enable_expert_parallel=True,
gpu_memory_utilization=0.2,
block_size=128,
quantization="ascend",
) as runner:
@@ -84,6 +85,7 @@ def test_models_pcp_dcp_basic():
runner.model.generate(prompts, sampling_params)
@wait_until_npu_memory_free()
def test_models_pcp_dcp_full_graph():
prompts = [
"The capital of France is", "Hello, my name is Tom, I am",
@@ -121,6 +123,7 @@ def test_models_pcp_dcp_full_graph():
runner.model.generate(prompts, sampling_params)
@wait_until_npu_memory_free()
def test_models_pcp_dcp_piece_wise():
prompts = [
"The capital of France is", "Hello, my name is Tom, I am",
@@ -172,6 +175,7 @@ def test_pcp_basic():
runner.model.generate(prompts, sampling_params)
@wait_until_npu_memory_free()
def test_pcp_full_graph():
prompts = [
"The capital of France is", "Hello, my name is Tom, I am",
@@ -195,6 +199,7 @@ def test_pcp_full_graph():
runner.model.generate(prompts, sampling_params)
@wait_until_npu_memory_free()
def test_pcp_piece_wise():
prompts = [
"The capital of France is", "Hello, my name is Tom, I am",
@@ -214,6 +219,7 @@ def test_pcp_piece_wise():
runner.model.generate(prompts, sampling_params)
@wait_until_npu_memory_free()
def test_dcp_basic():
prompts = [
"The capital of France is", "Hello, my name is Tom, I am",
@@ -233,6 +239,7 @@ def test_dcp_basic():
runner.model.generate(prompts, sampling_params)
@wait_until_npu_memory_free()
def test_dcp_full_graph():
prompts = [
"The capital of France is", "Hello, my name is Tom, I am",
@@ -256,6 +263,7 @@ def test_dcp_full_graph():
runner.model.generate(prompts, sampling_params)
@wait_until_npu_memory_free()
def test_dcp_piece_wise():
prompts = [
"The capital of France is", "Hello, my name is Tom, I am",

View File

@@ -34,6 +34,7 @@ class AscendPCPMetadata:
block_table_cp: torch.Tensor = None
valid_block_ids: torch.Tensor = None
prefill_q_cum_seqlens: torch.Tensor = None
block_arange: torch.Tensor = None
@dataclass

View File

@@ -5,6 +5,7 @@ import torch
import torch_npu
from vllm.config import VllmConfig
from vllm.distributed import get_dcp_group, get_pcp_group
from vllm.forward_context import get_forward_context
from vllm.triton_utils import HAS_TRITON
from vllm_ascend.attention.context_parallel.common_cp import AscendPCPMetadata
@@ -55,6 +56,7 @@ class AscendSFACPMetadataBuilder(AscendSFAMetadataBuilder):
dtype=torch.int32,
device=device,
)
self.block_arange_buffer = torch.arange(self.pcp_size * self.dcp_size, dtype=torch.int32, device=device)
def build(
self,
@@ -70,20 +72,7 @@ class AscendSFACPMetadataBuilder(AscendSFAMetadataBuilder):
assert num_decodes + num_prefills == num_reqs
assert num_decode_tokens + num_prefill_tokens == common_attn_metadata.num_actual_tokens
block_table = metadata_cls.block_table
valid_block_ids, new_block_table = block_table.flatten().unique(return_inverse=True)
num_blocks = valid_block_ids.shape[0]
# Note(qcs): `block_table_cp` will have dirty values in the part beyond kv_lens.
# We assume that we can always get the correct kv_lens or kv index,
# so we omit the dirty value processing here.
block_table_cp = (
new_block_table.unsqueeze(-1).to(block_table)
+ (torch.arange(self.pcp_size * self.dcp_size) * num_blocks).view(1, 1, -1).to(block_table)
).reshape(block_table.shape[0], -1)
sfa_cp_metadata = self.build_cp_metadata(
block_table_cp, valid_block_ids, metadata_cls.seq_lens, common_attn_metadata
)
sfa_cp_metadata = self.build_cp_metadata(self.block_arange_buffer, metadata_cls.seq_lens, common_attn_metadata)
metadata_cls.num_decode_tokens = num_decode_tokens
metadata_cls.num_decodes = num_decodes
metadata_cls.num_prefills = num_prefills
@@ -127,8 +116,7 @@ class AscendSFACPMetadataBuilder(AscendSFAMetadataBuilder):
def build_cp_metadata(
self,
block_table_cp: torch.Tensor,
valid_block_ids: torch.Tensor,
block_arange: torch.Tensor,
seq_lens: torch.Tensor,
common_attn_metadata: AscendCommonAttentionMetadata,
) -> AscendPCPMetadata | None:
@@ -144,8 +132,7 @@ class AscendSFACPMetadataBuilder(AscendSFAMetadataBuilder):
head_attn_nomask_seqlens=q_head_kv_lens,
tail_attn_nomask_seqlens=q_tail_kv_lens,
pcp_allgather_restore_idx=common_long_seq_metadata.pcp_allgather_restore_idx,
block_table_cp=block_table_cp,
valid_block_ids=valid_block_ids,
block_arange=block_arange,
)
@@ -198,16 +185,17 @@ class AscendSFACPImpl(AscendSFAImpl):
kv = kv_cache[0]
key_rope = kv_cache[1]
block_table = attn_metadata.block_table
assert attn_metadata.sfa_cp_metadata is not None
valid_block_ids = attn_metadata.sfa_cp_metadata.valid_block_ids
kv = self.gather_kv_cross_cp(kv, valid_block_ids)
key_rope = self.gather_kv_cross_cp(key_rope, valid_block_ids)
block_table = attn_metadata.sfa_cp_metadata.block_table_cp
block_arange = attn_metadata.sfa_cp_metadata.block_arange
kv, block_table = self.gather_kv_cross_cp(kv, block_table, block_arange)
key_rope, _ = self.gather_kv_cross_cp(key_rope)
assert block_table is not None
if self.pcp_size == 1:
return self._execute_sparse_flash_attention(
attn_output = self._execute_sparse_flash_attention(
ql_nope, q_pe, kv, key_rope, block_table, topk_indices, actual_seq_lengths_query, actual_seq_lengths_key
)
return self._align_to_graph_bucket_tokens(attn_output, attn_metadata)
num_decodes = attn_metadata.num_decodes
num_decode_tokens = attn_metadata.num_decode_tokens
num_prefills = attn_metadata.num_prefills
@@ -225,7 +213,7 @@ class AscendSFACPImpl(AscendSFAImpl):
)
if num_prefills < 1:
return decode_attn_out
return self._align_to_graph_bucket_tokens(decode_attn_out, attn_metadata)
# q split for head and tail
q_head_idx = attn_metadata.sfa_cp_metadata.q_head_idx
@@ -266,7 +254,30 @@ class AscendSFACPImpl(AscendSFAImpl):
if decode_attn_out is not None:
attn_output = torch.cat([decode_attn_out, attn_output], dim=0)
return attn_output
return self._align_to_graph_bucket_tokens(attn_output, attn_metadata)
def _align_to_graph_bucket_tokens(self, attn_output: torch.Tensor | None, attn_metadata: M) -> torch.Tensor | None:
if attn_output is None:
return None
# In graph/piecewise mode, output buffer uses graph bucket token size
# (forward_context.num_tokens), while PCP path may compute only valid
# tokens. Align to the larger one to avoid later write-back mismatch.
forward_context = get_forward_context()
target_tokens = max(
attn_metadata.num_input_tokens,
forward_context.num_tokens if forward_context is not None else 0,
)
if attn_output.shape[0] == target_tokens:
return attn_output
aligned = torch.zeros(
(target_tokens, *attn_output.shape[1:]),
dtype=attn_output.dtype,
device=attn_output.device,
)
valid_tokens = min(attn_output.shape[0], target_tokens)
aligned[:valid_tokens] = attn_output[:valid_tokens]
return aligned
def _execute_sparse_flash_attention(
self, ql_nope, q_pe, kv, key_rope, block_table, topk_indices, actual_seq_lengths_query, actual_seq_lengths_key
@@ -289,14 +300,20 @@ class AscendSFACPImpl(AscendSFAImpl):
)
return attn_output
def gather_kv_cross_cp(self, kv_cache: torch.Tensor, valid_block_ids: torch.Tensor) -> torch.Tensor:
def gather_kv_cross_cp(
self, kv_cache: torch.Tensor, block_tables: torch.Tensor | None = None, block_arange: torch.Tensor | None = None
) -> tuple[torch.Tensor, torch.Tensor | None]:
# Note(qcs): we need set kv_cache_interleave_size = block_size for sfa!!!
kv_cache = torch.index_select(kv_cache, 0, valid_block_ids)
block_num = kv_cache.shape[0]
if self.dcp_size > 1:
kv_cache = get_dcp_group().all_gather(kv_cache, 0)
if self.pcp_size > 1:
kv_cache = get_pcp_group().all_gather(kv_cache, 0)
return kv_cache
if block_tables is not None and block_arange is not None:
block_tables = (
block_tables.unsqueeze(-1) + (block_arange * block_num).view(1, 1, -1).to(block_tables)
).reshape(block_tables.shape[0], -1)
return kv_cache, block_tables
def indexer_select_post_process(
self,
@@ -330,9 +347,11 @@ class AscendSFACPImpl(AscendSFAImpl):
q = q_li
key = kv_cache[2]
block_table = attn_metadata.block_table
assert attn_metadata.sfa_cp_metadata is not None
key = self.gather_kv_cross_cp(key, attn_metadata.sfa_cp_metadata.valid_block_ids)
block_table = attn_metadata.sfa_cp_metadata.block_table_cp
block_arange = attn_metadata.sfa_cp_metadata.block_arange
key, block_table = self.gather_kv_cross_cp(key, block_table, block_arange)
if self.pcp_size == 1:
return self._execute_indexer_select(
@@ -353,7 +372,6 @@ class AscendSFACPImpl(AscendSFAImpl):
actual_seq_lengths_key[:num_decodes],
block_table[:num_decodes],
)
# prefill compute
if num_prefills == 0:
return decode_topk_indices