[Feat] Adapt the graph modes (piecewise and full_decode_only) of PCP and DCP for DeepSeek v3.2. (#6940)
### What this PR does / why we need it?
Adapt the graph mode (piecewise and full_decode_only) of PCP and DCP for
DeepSeek v3.2.
### How was this patch tested?
Test output:
{"object":"text_completion","model":"deepeek_v3","choices":[{"index":0,"text":"
the head of state and head of government of the United States,
indirectly elected to a four-year term by the American people through
the Electoral College. The officeholder leads the executive branch of
the federal government and is the commander-in-chief of the United
States","logprobs":null,"finish_reason":"length","stop_reason":null,"token_ids":null,"prompt_logprobs":null,"prompt_token_ids":null},{"index":1,"text":"
Paris. This is the largest city in France and its main political,
cultural and commercial center. The modern location of the city is the
north of the central part of the country, on the banks of the Seine
River Seine River Seine in
3\n\n","logprobs":null,"finish_reason":"length","stop_reason":null,"token_ids":null,"prompt_logprobs":null,"prompt_token_ids":null},{"index":2,"text":"
now\n\n# AI future is now\n\nThe world is changing at a rapid pace, and
artificial intelligence (AI) is at the forefront of this transformation.
From self-driving cars to virtual assistants, AI is already making a
significant impact on our daily
lives","logprobs":null,"finish_reason":"length","stop_reason":null,"token_ids":null,"prompt_logprobs":null,"prompt_token_ids":null},{"index":3,"text":"
a 3rd year student at the University of Lincoln studying Media
Production. This blog is about my work throughout my final year on the
course.\n\n## Tuesday 3 May 2016\n### Final Major Project -
Evaluation\n\nFor my final project
I","logprobs":null,"finish_reason":"length","stop_reason":null,"token_ids":null,"prompt_logprobs":null,"prompt_token_ids":null}],"service_tier":null,"system_fingerprint":null,"usage":{"prompt_tokens":27,"total_tokens":227,"completion_tokens":200,"prompt_tokens_details":null},"kv_transfer_params":null}
- vLLM version: v0.16.0
- vLLM main:
15d76f74e2
---------
Signed-off-by: xiaocongtou6 <2066962956@qq.com>
Signed-off-by: xiaocongtou6 <105542647+xiaocongtou6@users.noreply.github.com>
This commit is contained in:
@@ -25,6 +25,7 @@ from tests.e2e.conftest import VllmRunner, wait_until_npu_memory_free
|
||||
os.environ["HCCL_BUFFSIZE"] = "768"
|
||||
|
||||
|
||||
@wait_until_npu_memory_free()
|
||||
def test_models_pcp_dcp_basic():
|
||||
prompts = [
|
||||
"The capital of France is", "Hello, my name is Tom, I am",
|
||||
@@ -59,12 +60,12 @@ def test_models_pcp_dcp_basic():
|
||||
model = "vllm-ascend/DeepSeek-V3.2-W8A8-Pruning"
|
||||
with VllmRunner(
|
||||
model,
|
||||
enforce_eager=True,
|
||||
max_model_len=1024,
|
||||
tensor_parallel_size=2,
|
||||
prefill_context_parallel_size=2,
|
||||
decode_context_parallel_size=2,
|
||||
enable_expert_parallel=True,
|
||||
gpu_memory_utilization=0.2,
|
||||
block_size=128,
|
||||
quantization="ascend",
|
||||
) as runner:
|
||||
@@ -84,6 +85,7 @@ def test_models_pcp_dcp_basic():
|
||||
runner.model.generate(prompts, sampling_params)
|
||||
|
||||
|
||||
@wait_until_npu_memory_free()
|
||||
def test_models_pcp_dcp_full_graph():
|
||||
prompts = [
|
||||
"The capital of France is", "Hello, my name is Tom, I am",
|
||||
@@ -121,6 +123,7 @@ def test_models_pcp_dcp_full_graph():
|
||||
runner.model.generate(prompts, sampling_params)
|
||||
|
||||
|
||||
@wait_until_npu_memory_free()
|
||||
def test_models_pcp_dcp_piece_wise():
|
||||
prompts = [
|
||||
"The capital of France is", "Hello, my name is Tom, I am",
|
||||
@@ -172,6 +175,7 @@ def test_pcp_basic():
|
||||
runner.model.generate(prompts, sampling_params)
|
||||
|
||||
|
||||
@wait_until_npu_memory_free()
|
||||
def test_pcp_full_graph():
|
||||
prompts = [
|
||||
"The capital of France is", "Hello, my name is Tom, I am",
|
||||
@@ -195,6 +199,7 @@ def test_pcp_full_graph():
|
||||
runner.model.generate(prompts, sampling_params)
|
||||
|
||||
|
||||
@wait_until_npu_memory_free()
|
||||
def test_pcp_piece_wise():
|
||||
prompts = [
|
||||
"The capital of France is", "Hello, my name is Tom, I am",
|
||||
@@ -214,6 +219,7 @@ def test_pcp_piece_wise():
|
||||
runner.model.generate(prompts, sampling_params)
|
||||
|
||||
|
||||
@wait_until_npu_memory_free()
|
||||
def test_dcp_basic():
|
||||
prompts = [
|
||||
"The capital of France is", "Hello, my name is Tom, I am",
|
||||
@@ -233,6 +239,7 @@ def test_dcp_basic():
|
||||
runner.model.generate(prompts, sampling_params)
|
||||
|
||||
|
||||
@wait_until_npu_memory_free()
|
||||
def test_dcp_full_graph():
|
||||
prompts = [
|
||||
"The capital of France is", "Hello, my name is Tom, I am",
|
||||
@@ -256,6 +263,7 @@ def test_dcp_full_graph():
|
||||
runner.model.generate(prompts, sampling_params)
|
||||
|
||||
|
||||
@wait_until_npu_memory_free()
|
||||
def test_dcp_piece_wise():
|
||||
prompts = [
|
||||
"The capital of France is", "Hello, my name is Tom, I am",
|
||||
|
||||
@@ -34,6 +34,7 @@ class AscendPCPMetadata:
|
||||
block_table_cp: torch.Tensor = None
|
||||
valid_block_ids: torch.Tensor = None
|
||||
prefill_q_cum_seqlens: torch.Tensor = None
|
||||
block_arange: torch.Tensor = None
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -5,6 +5,7 @@ import torch
|
||||
import torch_npu
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import get_dcp_group, get_pcp_group
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.triton_utils import HAS_TRITON
|
||||
|
||||
from vllm_ascend.attention.context_parallel.common_cp import AscendPCPMetadata
|
||||
@@ -55,6 +56,7 @@ class AscendSFACPMetadataBuilder(AscendSFAMetadataBuilder):
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
self.block_arange_buffer = torch.arange(self.pcp_size * self.dcp_size, dtype=torch.int32, device=device)
|
||||
|
||||
def build(
|
||||
self,
|
||||
@@ -70,20 +72,7 @@ class AscendSFACPMetadataBuilder(AscendSFAMetadataBuilder):
|
||||
assert num_decodes + num_prefills == num_reqs
|
||||
assert num_decode_tokens + num_prefill_tokens == common_attn_metadata.num_actual_tokens
|
||||
|
||||
block_table = metadata_cls.block_table
|
||||
valid_block_ids, new_block_table = block_table.flatten().unique(return_inverse=True)
|
||||
num_blocks = valid_block_ids.shape[0]
|
||||
# Note(qcs): `block_table_cp` will have dirty values in the part beyond kv_lens.
|
||||
# We assume that we can always get the correct kv_lens or kv index,
|
||||
# so we omit the dirty value processing here.
|
||||
block_table_cp = (
|
||||
new_block_table.unsqueeze(-1).to(block_table)
|
||||
+ (torch.arange(self.pcp_size * self.dcp_size) * num_blocks).view(1, 1, -1).to(block_table)
|
||||
).reshape(block_table.shape[0], -1)
|
||||
|
||||
sfa_cp_metadata = self.build_cp_metadata(
|
||||
block_table_cp, valid_block_ids, metadata_cls.seq_lens, common_attn_metadata
|
||||
)
|
||||
sfa_cp_metadata = self.build_cp_metadata(self.block_arange_buffer, metadata_cls.seq_lens, common_attn_metadata)
|
||||
metadata_cls.num_decode_tokens = num_decode_tokens
|
||||
metadata_cls.num_decodes = num_decodes
|
||||
metadata_cls.num_prefills = num_prefills
|
||||
@@ -127,8 +116,7 @@ class AscendSFACPMetadataBuilder(AscendSFAMetadataBuilder):
|
||||
|
||||
def build_cp_metadata(
|
||||
self,
|
||||
block_table_cp: torch.Tensor,
|
||||
valid_block_ids: torch.Tensor,
|
||||
block_arange: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
common_attn_metadata: AscendCommonAttentionMetadata,
|
||||
) -> AscendPCPMetadata | None:
|
||||
@@ -144,8 +132,7 @@ class AscendSFACPMetadataBuilder(AscendSFAMetadataBuilder):
|
||||
head_attn_nomask_seqlens=q_head_kv_lens,
|
||||
tail_attn_nomask_seqlens=q_tail_kv_lens,
|
||||
pcp_allgather_restore_idx=common_long_seq_metadata.pcp_allgather_restore_idx,
|
||||
block_table_cp=block_table_cp,
|
||||
valid_block_ids=valid_block_ids,
|
||||
block_arange=block_arange,
|
||||
)
|
||||
|
||||
|
||||
@@ -198,16 +185,17 @@ class AscendSFACPImpl(AscendSFAImpl):
|
||||
kv = kv_cache[0]
|
||||
key_rope = kv_cache[1]
|
||||
|
||||
block_table = attn_metadata.block_table
|
||||
assert attn_metadata.sfa_cp_metadata is not None
|
||||
valid_block_ids = attn_metadata.sfa_cp_metadata.valid_block_ids
|
||||
kv = self.gather_kv_cross_cp(kv, valid_block_ids)
|
||||
key_rope = self.gather_kv_cross_cp(key_rope, valid_block_ids)
|
||||
block_table = attn_metadata.sfa_cp_metadata.block_table_cp
|
||||
|
||||
block_arange = attn_metadata.sfa_cp_metadata.block_arange
|
||||
kv, block_table = self.gather_kv_cross_cp(kv, block_table, block_arange)
|
||||
key_rope, _ = self.gather_kv_cross_cp(key_rope)
|
||||
assert block_table is not None
|
||||
if self.pcp_size == 1:
|
||||
return self._execute_sparse_flash_attention(
|
||||
attn_output = self._execute_sparse_flash_attention(
|
||||
ql_nope, q_pe, kv, key_rope, block_table, topk_indices, actual_seq_lengths_query, actual_seq_lengths_key
|
||||
)
|
||||
return self._align_to_graph_bucket_tokens(attn_output, attn_metadata)
|
||||
num_decodes = attn_metadata.num_decodes
|
||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
||||
num_prefills = attn_metadata.num_prefills
|
||||
@@ -225,7 +213,7 @@ class AscendSFACPImpl(AscendSFAImpl):
|
||||
)
|
||||
|
||||
if num_prefills < 1:
|
||||
return decode_attn_out
|
||||
return self._align_to_graph_bucket_tokens(decode_attn_out, attn_metadata)
|
||||
|
||||
# q split for head and tail
|
||||
q_head_idx = attn_metadata.sfa_cp_metadata.q_head_idx
|
||||
@@ -266,7 +254,30 @@ class AscendSFACPImpl(AscendSFAImpl):
|
||||
|
||||
if decode_attn_out is not None:
|
||||
attn_output = torch.cat([decode_attn_out, attn_output], dim=0)
|
||||
return attn_output
|
||||
return self._align_to_graph_bucket_tokens(attn_output, attn_metadata)
|
||||
|
||||
def _align_to_graph_bucket_tokens(self, attn_output: torch.Tensor | None, attn_metadata: M) -> torch.Tensor | None:
|
||||
if attn_output is None:
|
||||
return None
|
||||
# In graph/piecewise mode, output buffer uses graph bucket token size
|
||||
# (forward_context.num_tokens), while PCP path may compute only valid
|
||||
# tokens. Align to the larger one to avoid later write-back mismatch.
|
||||
forward_context = get_forward_context()
|
||||
target_tokens = max(
|
||||
attn_metadata.num_input_tokens,
|
||||
forward_context.num_tokens if forward_context is not None else 0,
|
||||
)
|
||||
|
||||
if attn_output.shape[0] == target_tokens:
|
||||
return attn_output
|
||||
aligned = torch.zeros(
|
||||
(target_tokens, *attn_output.shape[1:]),
|
||||
dtype=attn_output.dtype,
|
||||
device=attn_output.device,
|
||||
)
|
||||
valid_tokens = min(attn_output.shape[0], target_tokens)
|
||||
aligned[:valid_tokens] = attn_output[:valid_tokens]
|
||||
return aligned
|
||||
|
||||
def _execute_sparse_flash_attention(
|
||||
self, ql_nope, q_pe, kv, key_rope, block_table, topk_indices, actual_seq_lengths_query, actual_seq_lengths_key
|
||||
@@ -289,14 +300,20 @@ class AscendSFACPImpl(AscendSFAImpl):
|
||||
)
|
||||
return attn_output
|
||||
|
||||
def gather_kv_cross_cp(self, kv_cache: torch.Tensor, valid_block_ids: torch.Tensor) -> torch.Tensor:
|
||||
def gather_kv_cross_cp(
|
||||
self, kv_cache: torch.Tensor, block_tables: torch.Tensor | None = None, block_arange: torch.Tensor | None = None
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
# Note(qcs): we need to set kv_cache_interleave_size = block_size for sfa!!!
|
||||
kv_cache = torch.index_select(kv_cache, 0, valid_block_ids)
|
||||
block_num = kv_cache.shape[0]
|
||||
if self.dcp_size > 1:
|
||||
kv_cache = get_dcp_group().all_gather(kv_cache, 0)
|
||||
if self.pcp_size > 1:
|
||||
kv_cache = get_pcp_group().all_gather(kv_cache, 0)
|
||||
return kv_cache
|
||||
if block_tables is not None and block_arange is not None:
|
||||
block_tables = (
|
||||
block_tables.unsqueeze(-1) + (block_arange * block_num).view(1, 1, -1).to(block_tables)
|
||||
).reshape(block_tables.shape[0], -1)
|
||||
return kv_cache, block_tables
|
||||
|
||||
def indexer_select_post_process(
|
||||
self,
|
||||
@@ -330,9 +347,11 @@ class AscendSFACPImpl(AscendSFAImpl):
|
||||
q = q_li
|
||||
|
||||
key = kv_cache[2]
|
||||
block_table = attn_metadata.block_table
|
||||
assert attn_metadata.sfa_cp_metadata is not None
|
||||
key = self.gather_kv_cross_cp(key, attn_metadata.sfa_cp_metadata.valid_block_ids)
|
||||
block_table = attn_metadata.sfa_cp_metadata.block_table_cp
|
||||
block_arange = attn_metadata.sfa_cp_metadata.block_arange
|
||||
|
||||
key, block_table = self.gather_kv_cross_cp(key, block_table, block_arange)
|
||||
|
||||
if self.pcp_size == 1:
|
||||
return self._execute_indexer_select(
|
||||
@@ -353,7 +372,6 @@ class AscendSFACPImpl(AscendSFAImpl):
|
||||
actual_seq_lengths_key[:num_decodes],
|
||||
block_table[:num_decodes],
|
||||
)
|
||||
|
||||
# prefill compute
|
||||
if num_prefills == 0:
|
||||
return decode_topk_indices
|
||||
|
||||
Reference in New Issue
Block a user