[Feat]Adapt the graph mode (piecewise and full_decode_only) of PCP and DCP for DeepSeek v3.2. (#6940)

### What this PR does / why we need it?
Adapt the graph mode (piecewise and full_decode_only) of PCP and DCP for
DeepSeek v3.2.

### How was this patch tested?
Test output:

{"object":"text_completion","model":"deepeek_v3","choices":[{"index":0,"text":"
the head of state and head of government of the United States,
indirectly elected to a four-year term by the American people through
the Electoral College. The officeholder leads the executive branch of
the federal government and is the commander-in-chief of the United
States","logprobs":null,"finish_reason":"length","stop_reason":null,"token_ids":null,"prompt_logprobs":null,"prompt_token_ids":null},{"index":1,"text":"
Paris. This is the largest city in France and its main political,
cultural and commercial center. The modern location of the city is the
north of the central part of the country, on the banks of the Seine
River Seine River Seine in
3\n\n","logprobs":null,"finish_reason":"length","stop_reason":null,"token_ids":null,"prompt_logprobs":null,"prompt_token_ids":null},{"index":2,"text":"
now\n\n# AI future is now\n\nThe world is changing at a rapid pace, and
artificial intelligence (AI) is at the forefront of this transformation.
From self-driving cars to virtual assistants, AI is already making a
significant impact on our daily
lives","logprobs":null,"finish_reason":"length","stop_reason":null,"token_ids":null,"prompt_logprobs":null,"prompt_token_ids":null},{"index":3,"text":"
a 3rd year student at the University of Lincoln studying Media
Production. This blog is about my work throughout my final year on the
course.\n\n## Tuesday 3 May 2016\n### Final Major Project -
Evaluation\n\nFor my final project
I","logprobs":null,"finish_reason":"length","stop_reason":null,"token_ids":null,"prompt_logprobs":null,"prompt_token_ids":null}],"service_tier":null,"system_fingerprint":null,"usage":{"prompt_tokens":27,"total_tokens":227,"completion_tokens":200,"prompt_tokens_details":null},"kv_transfer_params":null}

- vLLM version: v0.16.0
- vLLM main:
15d76f74e2

---------

Signed-off-by: xiaocongtou6 <2066962956@qq.com>
Signed-off-by: xiaocongtou6 <105542647+xiaocongtou6@users.noreply.github.com>
This commit is contained in:
xiaocongtou6
2026-03-06 16:10:24 +08:00
committed by GitHub
parent a813eadd2d
commit bc0fd7ca72
3 changed files with 60 additions and 33 deletions

View File

@@ -25,6 +25,7 @@ from tests.e2e.conftest import VllmRunner, wait_until_npu_memory_free
os.environ["HCCL_BUFFSIZE"] = "768" os.environ["HCCL_BUFFSIZE"] = "768"
@wait_until_npu_memory_free()
def test_models_pcp_dcp_basic(): def test_models_pcp_dcp_basic():
prompts = [ prompts = [
"The capital of France is", "Hello, my name is Tom, I am", "The capital of France is", "Hello, my name is Tom, I am",
@@ -59,12 +60,12 @@ def test_models_pcp_dcp_basic():
model = "vllm-ascend/DeepSeek-V3.2-W8A8-Pruning" model = "vllm-ascend/DeepSeek-V3.2-W8A8-Pruning"
with VllmRunner( with VllmRunner(
model, model,
enforce_eager=True,
max_model_len=1024, max_model_len=1024,
tensor_parallel_size=2, tensor_parallel_size=2,
prefill_context_parallel_size=2, prefill_context_parallel_size=2,
decode_context_parallel_size=2, decode_context_parallel_size=2,
enable_expert_parallel=True, enable_expert_parallel=True,
gpu_memory_utilization=0.2,
block_size=128, block_size=128,
quantization="ascend", quantization="ascend",
) as runner: ) as runner:
@@ -84,6 +85,7 @@ def test_models_pcp_dcp_basic():
runner.model.generate(prompts, sampling_params) runner.model.generate(prompts, sampling_params)
@wait_until_npu_memory_free()
def test_models_pcp_dcp_full_graph(): def test_models_pcp_dcp_full_graph():
prompts = [ prompts = [
"The capital of France is", "Hello, my name is Tom, I am", "The capital of France is", "Hello, my name is Tom, I am",
@@ -121,6 +123,7 @@ def test_models_pcp_dcp_full_graph():
runner.model.generate(prompts, sampling_params) runner.model.generate(prompts, sampling_params)
@wait_until_npu_memory_free()
def test_models_pcp_dcp_piece_wise(): def test_models_pcp_dcp_piece_wise():
prompts = [ prompts = [
"The capital of France is", "Hello, my name is Tom, I am", "The capital of France is", "Hello, my name is Tom, I am",
@@ -172,6 +175,7 @@ def test_pcp_basic():
runner.model.generate(prompts, sampling_params) runner.model.generate(prompts, sampling_params)
@wait_until_npu_memory_free()
def test_pcp_full_graph(): def test_pcp_full_graph():
prompts = [ prompts = [
"The capital of France is", "Hello, my name is Tom, I am", "The capital of France is", "Hello, my name is Tom, I am",
@@ -195,6 +199,7 @@ def test_pcp_full_graph():
runner.model.generate(prompts, sampling_params) runner.model.generate(prompts, sampling_params)
@wait_until_npu_memory_free()
def test_pcp_piece_wise(): def test_pcp_piece_wise():
prompts = [ prompts = [
"The capital of France is", "Hello, my name is Tom, I am", "The capital of France is", "Hello, my name is Tom, I am",
@@ -214,6 +219,7 @@ def test_pcp_piece_wise():
runner.model.generate(prompts, sampling_params) runner.model.generate(prompts, sampling_params)
@wait_until_npu_memory_free()
def test_dcp_basic(): def test_dcp_basic():
prompts = [ prompts = [
"The capital of France is", "Hello, my name is Tom, I am", "The capital of France is", "Hello, my name is Tom, I am",
@@ -233,6 +239,7 @@ def test_dcp_basic():
runner.model.generate(prompts, sampling_params) runner.model.generate(prompts, sampling_params)
@wait_until_npu_memory_free()
def test_dcp_full_graph(): def test_dcp_full_graph():
prompts = [ prompts = [
"The capital of France is", "Hello, my name is Tom, I am", "The capital of France is", "Hello, my name is Tom, I am",
@@ -256,6 +263,7 @@ def test_dcp_full_graph():
runner.model.generate(prompts, sampling_params) runner.model.generate(prompts, sampling_params)
@wait_until_npu_memory_free()
def test_dcp_piece_wise(): def test_dcp_piece_wise():
prompts = [ prompts = [
"The capital of France is", "Hello, my name is Tom, I am", "The capital of France is", "Hello, my name is Tom, I am",

View File

@@ -34,6 +34,7 @@ class AscendPCPMetadata:
block_table_cp: torch.Tensor = None block_table_cp: torch.Tensor = None
valid_block_ids: torch.Tensor = None valid_block_ids: torch.Tensor = None
prefill_q_cum_seqlens: torch.Tensor = None prefill_q_cum_seqlens: torch.Tensor = None
block_arange: torch.Tensor = None
@dataclass @dataclass

View File

@@ -5,6 +5,7 @@ import torch
import torch_npu import torch_npu
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed import get_dcp_group, get_pcp_group from vllm.distributed import get_dcp_group, get_pcp_group
from vllm.forward_context import get_forward_context
from vllm.triton_utils import HAS_TRITON from vllm.triton_utils import HAS_TRITON
from vllm_ascend.attention.context_parallel.common_cp import AscendPCPMetadata from vllm_ascend.attention.context_parallel.common_cp import AscendPCPMetadata
@@ -55,6 +56,7 @@ class AscendSFACPMetadataBuilder(AscendSFAMetadataBuilder):
dtype=torch.int32, dtype=torch.int32,
device=device, device=device,
) )
self.block_arange_buffer = torch.arange(self.pcp_size * self.dcp_size, dtype=torch.int32, device=device)
def build( def build(
self, self,
@@ -70,20 +72,7 @@ class AscendSFACPMetadataBuilder(AscendSFAMetadataBuilder):
assert num_decodes + num_prefills == num_reqs assert num_decodes + num_prefills == num_reqs
assert num_decode_tokens + num_prefill_tokens == common_attn_metadata.num_actual_tokens assert num_decode_tokens + num_prefill_tokens == common_attn_metadata.num_actual_tokens
block_table = metadata_cls.block_table sfa_cp_metadata = self.build_cp_metadata(self.block_arange_buffer, metadata_cls.seq_lens, common_attn_metadata)
valid_block_ids, new_block_table = block_table.flatten().unique(return_inverse=True)
num_blocks = valid_block_ids.shape[0]
# Note(qcs): `block_table_cp` will have dirty values in the part beyond kv_lens.
# We assume that we can always get the correct kv_lens or kv index,
# so we omit the dirty value processing here.
block_table_cp = (
new_block_table.unsqueeze(-1).to(block_table)
+ (torch.arange(self.pcp_size * self.dcp_size) * num_blocks).view(1, 1, -1).to(block_table)
).reshape(block_table.shape[0], -1)
sfa_cp_metadata = self.build_cp_metadata(
block_table_cp, valid_block_ids, metadata_cls.seq_lens, common_attn_metadata
)
metadata_cls.num_decode_tokens = num_decode_tokens metadata_cls.num_decode_tokens = num_decode_tokens
metadata_cls.num_decodes = num_decodes metadata_cls.num_decodes = num_decodes
metadata_cls.num_prefills = num_prefills metadata_cls.num_prefills = num_prefills
@@ -127,8 +116,7 @@ class AscendSFACPMetadataBuilder(AscendSFAMetadataBuilder):
def build_cp_metadata( def build_cp_metadata(
self, self,
block_table_cp: torch.Tensor, block_arange: torch.Tensor,
valid_block_ids: torch.Tensor,
seq_lens: torch.Tensor, seq_lens: torch.Tensor,
common_attn_metadata: AscendCommonAttentionMetadata, common_attn_metadata: AscendCommonAttentionMetadata,
) -> AscendPCPMetadata | None: ) -> AscendPCPMetadata | None:
@@ -144,8 +132,7 @@ class AscendSFACPMetadataBuilder(AscendSFAMetadataBuilder):
head_attn_nomask_seqlens=q_head_kv_lens, head_attn_nomask_seqlens=q_head_kv_lens,
tail_attn_nomask_seqlens=q_tail_kv_lens, tail_attn_nomask_seqlens=q_tail_kv_lens,
pcp_allgather_restore_idx=common_long_seq_metadata.pcp_allgather_restore_idx, pcp_allgather_restore_idx=common_long_seq_metadata.pcp_allgather_restore_idx,
block_table_cp=block_table_cp, block_arange=block_arange,
valid_block_ids=valid_block_ids,
) )
@@ -198,16 +185,17 @@ class AscendSFACPImpl(AscendSFAImpl):
kv = kv_cache[0] kv = kv_cache[0]
key_rope = kv_cache[1] key_rope = kv_cache[1]
block_table = attn_metadata.block_table
assert attn_metadata.sfa_cp_metadata is not None assert attn_metadata.sfa_cp_metadata is not None
valid_block_ids = attn_metadata.sfa_cp_metadata.valid_block_ids block_arange = attn_metadata.sfa_cp_metadata.block_arange
kv = self.gather_kv_cross_cp(kv, valid_block_ids) kv, block_table = self.gather_kv_cross_cp(kv, block_table, block_arange)
key_rope = self.gather_kv_cross_cp(key_rope, valid_block_ids) key_rope, _ = self.gather_kv_cross_cp(key_rope)
block_table = attn_metadata.sfa_cp_metadata.block_table_cp assert block_table is not None
if self.pcp_size == 1: if self.pcp_size == 1:
return self._execute_sparse_flash_attention( attn_output = self._execute_sparse_flash_attention(
ql_nope, q_pe, kv, key_rope, block_table, topk_indices, actual_seq_lengths_query, actual_seq_lengths_key ql_nope, q_pe, kv, key_rope, block_table, topk_indices, actual_seq_lengths_query, actual_seq_lengths_key
) )
return self._align_to_graph_bucket_tokens(attn_output, attn_metadata)
num_decodes = attn_metadata.num_decodes num_decodes = attn_metadata.num_decodes
num_decode_tokens = attn_metadata.num_decode_tokens num_decode_tokens = attn_metadata.num_decode_tokens
num_prefills = attn_metadata.num_prefills num_prefills = attn_metadata.num_prefills
@@ -225,7 +213,7 @@ class AscendSFACPImpl(AscendSFAImpl):
) )
if num_prefills < 1: if num_prefills < 1:
return decode_attn_out return self._align_to_graph_bucket_tokens(decode_attn_out, attn_metadata)
# q split for head and tail # q split for head and tail
q_head_idx = attn_metadata.sfa_cp_metadata.q_head_idx q_head_idx = attn_metadata.sfa_cp_metadata.q_head_idx
@@ -266,7 +254,30 @@ class AscendSFACPImpl(AscendSFAImpl):
if decode_attn_out is not None: if decode_attn_out is not None:
attn_output = torch.cat([decode_attn_out, attn_output], dim=0) attn_output = torch.cat([decode_attn_out, attn_output], dim=0)
return attn_output return self._align_to_graph_bucket_tokens(attn_output, attn_metadata)
def _align_to_graph_bucket_tokens(self, attn_output: torch.Tensor | None, attn_metadata: M) -> torch.Tensor | None:
if attn_output is None:
return None
# In graph/piecewise mode, output buffer uses graph bucket token size
# (forward_context.num_tokens), while PCP path may compute only valid
# tokens. Align to the larger one to avoid later write-back mismatch.
forward_context = get_forward_context()
target_tokens = max(
attn_metadata.num_input_tokens,
forward_context.num_tokens if forward_context is not None else 0,
)
if attn_output.shape[0] == target_tokens:
return attn_output
aligned = torch.zeros(
(target_tokens, *attn_output.shape[1:]),
dtype=attn_output.dtype,
device=attn_output.device,
)
valid_tokens = min(attn_output.shape[0], target_tokens)
aligned[:valid_tokens] = attn_output[:valid_tokens]
return aligned
def _execute_sparse_flash_attention( def _execute_sparse_flash_attention(
self, ql_nope, q_pe, kv, key_rope, block_table, topk_indices, actual_seq_lengths_query, actual_seq_lengths_key self, ql_nope, q_pe, kv, key_rope, block_table, topk_indices, actual_seq_lengths_query, actual_seq_lengths_key
@@ -289,14 +300,20 @@ class AscendSFACPImpl(AscendSFAImpl):
) )
return attn_output return attn_output
def gather_kv_cross_cp(self, kv_cache: torch.Tensor, valid_block_ids: torch.Tensor) -> torch.Tensor: def gather_kv_cross_cp(
self, kv_cache: torch.Tensor, block_tables: torch.Tensor | None = None, block_arange: torch.Tensor | None = None
) -> tuple[torch.Tensor, torch.Tensor | None]:
# Note(qcs): we need set kv_cache_interleave_size = block_size for sfa!!! # Note(qcs): we need set kv_cache_interleave_size = block_size for sfa!!!
kv_cache = torch.index_select(kv_cache, 0, valid_block_ids) block_num = kv_cache.shape[0]
if self.dcp_size > 1: if self.dcp_size > 1:
kv_cache = get_dcp_group().all_gather(kv_cache, 0) kv_cache = get_dcp_group().all_gather(kv_cache, 0)
if self.pcp_size > 1: if self.pcp_size > 1:
kv_cache = get_pcp_group().all_gather(kv_cache, 0) kv_cache = get_pcp_group().all_gather(kv_cache, 0)
return kv_cache if block_tables is not None and block_arange is not None:
block_tables = (
block_tables.unsqueeze(-1) + (block_arange * block_num).view(1, 1, -1).to(block_tables)
).reshape(block_tables.shape[0], -1)
return kv_cache, block_tables
def indexer_select_post_process( def indexer_select_post_process(
self, self,
@@ -330,9 +347,11 @@ class AscendSFACPImpl(AscendSFAImpl):
q = q_li q = q_li
key = kv_cache[2] key = kv_cache[2]
block_table = attn_metadata.block_table
assert attn_metadata.sfa_cp_metadata is not None assert attn_metadata.sfa_cp_metadata is not None
key = self.gather_kv_cross_cp(key, attn_metadata.sfa_cp_metadata.valid_block_ids) block_arange = attn_metadata.sfa_cp_metadata.block_arange
block_table = attn_metadata.sfa_cp_metadata.block_table_cp
key, block_table = self.gather_kv_cross_cp(key, block_table, block_arange)
if self.pcp_size == 1: if self.pcp_size == 1:
return self._execute_indexer_select( return self._execute_indexer_select(
@@ -353,7 +372,6 @@ class AscendSFACPImpl(AscendSFAImpl):
actual_seq_lengths_key[:num_decodes], actual_seq_lengths_key[:num_decodes],
block_table[:num_decodes], block_table[:num_decodes],
) )
# prefill compute # prefill compute
if num_prefills == 0: if num_prefills == 0:
return decode_topk_indices return decode_topk_indices