[feature] support pcp + mtp in full graph (#4572)
1. support pcp + mtp in full graph
2. pcp/dcp related mtp bugfix
3. support pcp + mtpx
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
Signed-off-by: zhangsicheng5 <zhangsicheng5@huawei.com>
This commit is contained in:
@@ -67,6 +67,12 @@ class AscendPrefillContextParallelMetadata:
|
||||
|
||||
pcp_prefill_mask: torch.Tensor = None
|
||||
|
||||
# original query_lens before pcp split
|
||||
query_lens_pcp_full_cpu: torch.Tensor = None
|
||||
|
||||
# original max_query_len before pcp split
|
||||
max_query_len_pcp_full: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class AscendCommonAttentionMetadata:
|
||||
@@ -189,6 +195,8 @@ def split_decodes_and_prefills(
|
||||
"""
|
||||
Assuming a reordered batch, finds the boundary between prefill and decode
|
||||
requests.
|
||||
While pcp > 1, query_lens is split across pcp ranks, so we pass in the
|
||||
original query_lens and max_query_len to distinguish prefills and decodes.
|
||||
|
||||
Args:
|
||||
common_attn_metadata: AscendCommonAttentionMetadata object containing the
|
||||
@@ -201,7 +209,13 @@ def split_decodes_and_prefills(
|
||||
num_decode_tokens: The number of tokens in the decode requests.
|
||||
num_prefill_tokens: The number of tokens in the prefill requests.
|
||||
"""
|
||||
max_query_len = common_attn_metadata.max_query_len
|
||||
long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
|
||||
query_lens_pcp_full = long_seq_metadata.query_lens_pcp_full_cpu \
|
||||
if long_seq_metadata else None
|
||||
max_query_len_pcp_full = long_seq_metadata.max_query_len_pcp_full \
|
||||
if long_seq_metadata else 0
|
||||
max_query_len = common_attn_metadata.max_query_len \
|
||||
if max_query_len_pcp_full == 0 else max_query_len_pcp_full
|
||||
num_reqs = common_attn_metadata.num_reqs
|
||||
num_tokens = common_attn_metadata.num_actual_tokens
|
||||
query_start_loc = common_attn_metadata.query_start_loc_cpu
|
||||
@@ -209,7 +223,8 @@ def split_decodes_and_prefills(
|
||||
if max_query_len <= decode_threshold:
|
||||
return num_reqs, 0, num_tokens, 0
|
||||
|
||||
query_lens = query_start_loc[1:] - query_start_loc[:-1]
|
||||
query_lens = (query_start_loc[1:] - query_start_loc[:-1]) \
|
||||
if query_lens_pcp_full is None else query_lens_pcp_full
|
||||
is_prefill = query_lens > decode_threshold
|
||||
if not torch.any(is_prefill):
|
||||
return num_reqs, 0, num_tokens, 0
|
||||
|
||||
Reference in New Issue
Block a user