[Feat] support basic pcp&dcp for qwen3next (#6091)

### What this PR does / why we need it?
This PR implements Context Parallelism (CP) support for the Qwen3-Next
model, including PCP (Prefill Context Parallelism) and DCP (Decode
Context Parallelism).
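
At a glance: PCP splits each prefill's tokens across ranks, every rank runs attention on its shard, and an all-gather plus a precomputed restore index puts the hidden states back into prompt order. A minimal single-process sketch of that round trip (toy tensors and a fake all-gather, not this PR's actual API):

```python
import torch

# pcp_size = 2: rank 0 takes head+tail chunks, rank 1 the middle (toy split).
prompt = torch.arange(8)                 # token positions 0..7
shards = [prompt[[0, 1, 6, 7]],          # rank 0
          prompt[[2, 3, 4, 5]]]          # rank 1

gathered = torch.cat(shards)             # stand-in for the PCP all_gather
restore_idx = torch.argsort(gathered)    # what pcp_allgather_restore_idx holds
assert torch.equal(gathered[restore_idx], prompt)
```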

- vLLM version: v0.15.0
- vLLM main: f176443446

---------

Signed-off-by: SunnyLee219 <3294305115@qq.com>
Signed-off-by: Jingchun Gao <gaojingchun1@huawei.com>
Signed-off-by: 白永斌 <baiyongbin3@h-partners.com>
Signed-off-by: Bai Yongbin <845473182@qq.com>
Co-authored-by: SunnyLee219 <3294305115@qq.com>
Co-authored-by: Jingchun Gao <gaojingchun1@huawei.com>
Co-authored-by: 白永斌 <baiyongbin3@h-partners.com>
Co-authored-by: Mengqing Cao <cmq0113@163.com>
Commit 9d09488b4a (parent 64fba51275), authored by Bai Yongbin on 2026-02-28 21:44:08 +08:00 and committed by GitHub.
16 changed files with 906 additions and 81 deletions


@@ -383,6 +383,7 @@ class NPUModelRunner(GPUModelRunner):
self.intermediate_tensors: IntermediateTensors | None = None
self.reorder_batch_threshold: int | None = None
self.long_seq_metadata = None
self.query_lens: torch.Tensor | None = None
self.cpu_slot_mapping = None
@property
@@ -543,10 +544,12 @@ class NPUModelRunner(GPUModelRunner):
self,
scheduler_output: "SchedulerOutput",
num_scheduled_tokens: np.ndarray,
) -> tuple[torch.Tensor, SpecDecodeMetadata | None]:
) -> tuple[torch.Tensor, SpecDecodeMetadata | None, int]:
"""
:return: tuple[
logits_indices, spec_decode_metadata,
logits_indices,
spec_decode_metadata,
total_num_scheduled_tokens,
]
"""
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
@@ -610,11 +613,10 @@ class NPUModelRunner(GPUModelRunner):
if self.pcp_size > 1:
num_scheduled_tokens[:num_reqs], position_pcp = self.pcp_manager.update_tokens_for_pcp(
num_scheduled_tokens[:num_reqs],
self.arange_np,
num_scheduled_tokens[:num_reqs], self.arange_np
)
# Recompute the total after PCP has split the sequences.
total_num_scheduled_tokens = sum(num_scheduled_tokens)
total_num_scheduled_tokens = sum(num_scheduled_tokens[:num_reqs])
req_indices = np.repeat(self.arange_np[:num_reqs], num_scheduled_tokens)
cu_num_tokens, _ = self._get_cumsum_and_arange(num_scheduled_tokens)
positions_np = self.positions.np[:total_num_scheduled_tokens]
@@ -623,7 +625,11 @@ class NPUModelRunner(GPUModelRunner):
position_pcp[:total_num_scheduled_tokens],
out=positions_np,
)
self.query_lens = torch.from_numpy(num_scheduled_tokens)
if self.pcp_size > 1 and self.pcp_manager.pcp_use_hybrid_attn:
assert self.pcp_manager.num_scheduled_tokens_padded is not None
self.query_lens = torch.from_numpy(self.pcp_manager.num_scheduled_tokens_padded)
else:
self.query_lens = torch.from_numpy(num_scheduled_tokens)
# Get token indices.
# E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
@@ -702,6 +708,8 @@ class NPUModelRunner(GPUModelRunner):
self.seq_lens.np[:num_reqs] = self.input_batch.num_computed_tokens_cpu[:num_reqs] + num_scheduled_tokens
self.seq_lens.copy_to_gpu()
# Fill unused with -1. Needed for reshape_and_cache in attention_cp
self.query_start_loc.gpu[num_reqs + 1 :].fill_(-1)
self.seq_lens.gpu[num_reqs:].fill_(0)
# Copy the tensors to the NPU.
@@ -732,6 +740,7 @@ class NPUModelRunner(GPUModelRunner):
num_tokens_np = np.array(num_tokens, dtype=np.int32)
base_num_reqs = self.input_batch.num_reqs
num_reqs = base_num_reqs
tokens_original = None
if self.pcp_size > 1:
# When pcp_size > 1, we need the original num_scheduled_tokens (before the
# PCP split) to calculate discard_requests_mask
@@ -758,7 +767,7 @@ class NPUModelRunner(GPUModelRunner):
num_draft_tokens = None
num_sampled_tokens = np.ones(num_reqs, dtype=np.int32)
if self.use_cp:
logits_indices = self.pcp_manager.get_logits_indices(cu_num_tokens)
logits_indices = self.pcp_manager.get_logits_indices(cu_num_tokens, num_reqs, tokens_original)
logits_indices = logits_indices.pin_memory().to(self.device, non_blocking=True)
else:
logits_indices = self.query_start_loc.gpu[1 : num_reqs + 1] - 1
@@ -807,7 +816,11 @@ class NPUModelRunner(GPUModelRunner):
max_num_reqs_across_dp = self.max_num_reqs * self.uniform_decode_query_len
logits_indices = nn.functional.pad(logits_indices, (0, max_num_reqs_across_dp - logits_indices.shape[0]))
return logits_indices, spec_decode_metadata
return (
logits_indices,
spec_decode_metadata,
total_num_scheduled_tokens,
)
def _build_attn_state(self, num_reqs, num_scheduled_tokens, num_valid_tokens):
if np.all(self.input_batch.num_computed_tokens_cpu[:num_reqs] == 0):
@@ -1152,6 +1165,7 @@ class NPUModelRunner(GPUModelRunner):
(
logits_indices,
spec_decode_metadata,
total_num_scheduled_tokens,
) = self._prepare_inputs(
scheduler_output,
num_scheduled_tokens_np,
@@ -1220,7 +1234,9 @@ class NPUModelRunner(GPUModelRunner):
num_reqs_padded = self._pad_query_start_loc_for_fia(num_tokens_padded, num_reqs_padded, num_reqs)
(attn_metadata, spec_decode_common_attn_metadata) = self._build_attention_metadata(
num_tokens=num_tokens_unpadded,
num_tokens=num_tokens_unpadded
if not (self.use_cp and self.pcp_manager.pcp_use_hybrid_attn)
else total_num_scheduled_tokens,
num_tokens_padded=num_tokens_padded,
num_reqs=num_reqs,
num_reqs_padded=num_reqs_padded,
@@ -1240,7 +1256,13 @@ class NPUModelRunner(GPUModelRunner):
intermediate_tensors,
model_kwargs,
ec_connector_output,
) = self._preprocess(scheduler_output, num_tokens_padded, intermediate_tensors)
) = self._preprocess(
scheduler_output,
num_tokens_padded
if not (self.use_cp and self.pcp_manager.pcp_use_hybrid_attn)
else total_num_scheduled_tokens,
intermediate_tensors,
)
if self.dynamic_eplb:
self.eplb_updator.take_update_info_from_eplb_process()
@@ -1287,6 +1309,7 @@ class NPUModelRunner(GPUModelRunner):
batch_descriptor=batch_desc,
num_actual_tokens=scheduler_output.total_num_scheduled_tokens,
model_instance=self.model,
max_tokens_across_pcp=0 if self.pcp_size == 1 else self.pcp_manager.max_num_tokens_across_pcp,
skip_compiled=has_encoder_input,
),
self.maybe_get_kv_connector_output(scheduler_output) as kv_connector_output,
@@ -1922,11 +1945,16 @@ class NPUModelRunner(GPUModelRunner):
def _get_block_table_and_slot_mapping(kv_cache_gid: int):
assert num_reqs_padded is not None and num_tokens_padded is not None
kv_cache_spec = kv_cache_groups[kv_cache_gid].kv_cache_spec
maybe_pcp_full_tokens = (
num_tokens_padded
if self.pcp_size == 1
else num_tokens * self.pcp_size - sum(self.pcp_manager.num_pcp_pads_cpu[:num_reqs])
)
if self.pcp_size > 1:
total_num_pcp_pads = sum(self.pcp_manager.num_pcp_pads_cpu[:num_reqs])
if self.pcp_manager.pcp_use_hybrid_attn:
num_scheduled_tokens_padded = self.pcp_manager.num_scheduled_tokens_padded
assert num_scheduled_tokens_padded is not None
maybe_pcp_full_tokens = sum(num_scheduled_tokens_padded) * self.pcp_size - total_num_pcp_pads
else:
maybe_pcp_full_tokens = num_tokens * self.pcp_size - total_num_pcp_pads
else:
maybe_pcp_full_tokens = num_tokens_padded
if isinstance(kv_cache_spec, EncoderOnlyAttentionSpec):
blk_table_tensor = torch.zeros(
(num_reqs_padded, 1),
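
The branch above sizes the restored KV view that the block table must cover. A toy check of the arithmetic, with stand-in values:

```python
# Each PCP rank holds num_tokens local (padded) tokens; the all-gathered,
# unpadded view is the per-rank count times pcp_size minus all pad tokens.
num_tokens = 8            # local tokens on this rank (stand-in)
pcp_size = 2
total_num_pcp_pads = 3    # stand-in for sum(num_pcp_pads_cpu[:num_reqs])
maybe_pcp_full_tokens = num_tokens * pcp_size - total_num_pcp_pads
assert maybe_pcp_full_tokens == 13
```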

---------

@@ -21,6 +21,7 @@ from typing import TYPE_CHECKING
import numpy as np
import torch
import torch.nn.functional as F
from vllm.config import VllmConfig
from vllm.v1.utils import CpuGpuBuffer
@@ -110,6 +111,20 @@ class PCPManager:
self.query_lens_pcp_full = CpuGpuBuffer(
self.max_num_reqs, dtype=torch.int32, device=device, pin_memory=pin_memory
)
self.pcp_fa_query_idx = torch.zeros(
self.max_num_tokens + 2 * self.max_num_reqs, dtype=torch.int32, device=self.device
)
self.pcp_enter_fa_restore_idx = torch.zeros(
self.max_num_tokens + 2 * self.pcp_world_size * self.max_num_reqs, dtype=torch.int32, device=self.device
)
self.pcp_use_hybrid_attn = self.vllm_config.model_config.hf_config.model_type == "qwen3_next"
self.pcp_pads_logits_hybrid_attn = torch.ones(self.max_num_reqs, dtype=torch.int32) * (self.pcp_world_size - 1)
self.pcp_padded_tokens_fla = 0
self.pcp_padded_tokens_length = 0
self.num_scheduled_tokens_padded = None
self.max_num_tokens_across_pcp = 0
self.pcp_tokens_padded = None
def _get_cumsum_and_arange(
self,
@@ -184,9 +199,10 @@ class PCPManager:
Tuple (pcp_tokens, pcp_positions):
- pcp_tokens: number of tokens per request that this PCP rank will
actually process (after splitting / replication).
For hybrid-attention models: the number of unpadded tokens
per request.
- pcp_positions: flattened positions for those tokens on this rank,
used to build the positions buffer for the model.
Example:
>>> Assume tokens = [1, 5, 8], pcp_world_size = 2. After _update_tokens_for_pcp.
>>> pcp_rank = 0 get ([1, 4, 4], [0, 0, 1, 6, 7, 0, 1, 6, 7])
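
The docstring example follows from a head/tail chunking scheme: the padded sequence is cut into 2 * pcp_world_size chunks and rank i takes chunk i plus its mirror chunk, so every rank sees a balanced mix of early and late positions. A sketch that reproduces the example (an illustration of the scheme, not the helper itself; decode tokens are replicated rather than split):

```python
import numpy as np

def head_tail_positions(num_tokens: int, world: int, rank: int) -> np.ndarray:
    # Pad to a multiple of 2*world, cut into 2*world chunks; rank i takes
    # chunk i (head) and chunk 2*world-1-i (tail) for load balance.
    chunk = -(-num_tokens // (2 * world))  # ceil division
    head = np.arange(rank * chunk, (rank + 1) * chunk)
    tail = np.arange((2 * world - 1 - rank) * chunk, (2 * world - rank) * chunk)
    return np.concatenate([head, tail])

# The prefill requests (5 and 8 tokens) with pcp_world_size = 2, as above:
print(head_tail_positions(5, 2, 0))  # [0 1 6 7]
print(head_tail_positions(8, 2, 0))  # [0 1 6 7]
print(head_tail_positions(5, 2, 1))  # [2 3 4 5]
```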
@@ -219,9 +235,10 @@ class PCPManager:
# cu_padded_tokens: cumulative sum of padded token counts,
# pcp_padded_arange: per-request arange flattened for padded tokens.
cu_padded_tokens, pcp_padded_arange = self._get_cumsum_and_arange(num_padded_scheduled_tokens, arange_np)
self.pcp_padded_tokens_length = pcp_padded_arange.shape[0]
# Build the mask that marks which positions in the padded allgather buffer
# correspond to real (unpadded) tokens.
self.pcp_unpad_mask_cpu[: pcp_padded_arange.shape[0]] = pcp_padded_arange < np.repeat(
self.pcp_unpad_mask_cpu[: self.pcp_padded_tokens_length] = pcp_padded_arange < np.repeat(
num_scheduled_tokens, num_padded_scheduled_tokens
)
unpad_mask_decode = self.pcp_unpad_mask_cpu[: self.num_decode_tokens * self.pcp_world_size]
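
The mask construction compares a flattened per-request arange over the padded lengths against the repeated real lengths; with toy numbers:

```python
import numpy as np

num_scheduled = np.array([1, 5, 8])  # real tokens per request (toy values)
num_padded = np.array([4, 8, 8])     # per-request lengths after PCP padding

# Flattened per-request arange over the padded lengths: [0..3, 0..7, 0..7].
padded_arange = np.concatenate([np.arange(n) for n in num_padded])
unpad_mask = padded_arange < np.repeat(num_scheduled, num_padded)
print(unpad_mask[:4])  # [ True False False False] -> 1 real token, 3 pads
```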
@@ -272,6 +289,9 @@ class PCPManager:
return positions
positions = get_current_rank_positions(0, self.pcp_world_rank)
padded_pos_start_loc = np.roll(cu_padded_tokens, 1)
padded_pos_start_loc[0] = 0
# Decode tokens are duplicated only after the all-gather (AG), and their
# positions are the same as without prefill context parallelism.
if self.num_decode_reqs > 0:
@@ -279,35 +299,192 @@ class PCPManager:
num_scheduled_tokens[: self.num_decode_reqs], arange_np
)[1]
if self.pcp_use_hybrid_attn:
max_scheduled_prefill_tokens = 0
self.pcp_padded_tokens_fla = 0
if self.num_decode_reqs > 0:
num_padded_scheduled_tokens[: self.num_decode_reqs] = (
num_padded_scheduled_tokens[: self.num_decode_reqs] // self.pcp_world_size
)
self.total_pcp_padding_tokens_fla = 0
# There are prefill requests.
if self.num_reqs - self.num_decode_reqs > 0:
prefill_tokens_tensor = torch.Tensor(num_scheduled_tokens[self.num_decode_tokens :])
# shape [num_prefill_reqs, pcp_world_size, 1]; e.g. 5 tokens on 2 ranks -> [[3,2]]; tokens [7,5] on 4 ranks -> [[2,2,2,1],[2,1,1,1]]
num_prefill_tokens_allranks = (
self._get_cp_local_seq_lens(prefill_tokens_tensor, self.pcp_world_size, 1, 1).long().numpy()
)
# [3] [2] | [2,2] [2,1] [2,1] [1,1]
num_prefill_scheduled_tokens_linear = num_prefill_tokens_allranks[:, self.pcp_world_rank, 0]
num_padded_scheduled_tokens[self.num_decode_reqs :] = num_prefill_scheduled_tokens_linear
# [[3,5]] | [[0,0,0,0,0],[0,0,0,0,0]]
num_prefill_tokens_start_loc = np.zeros(
(self.num_reqs - self.num_decode_reqs, self.pcp_world_size + 1), dtype=np.int64
)
# [[0,3,5]] | [[0,2,4,6,7],[0,2,3,4,5]]
num_prefill_tokens_start_loc[:, 1:] = np.cumsum(num_prefill_tokens_allranks[..., 0], axis=-1)
# [0] [3] | [0,0] [2,2] [4,3] [6,4] [7,5]
num_prefill_tokens_cu_ranks = num_prefill_tokens_start_loc[:, self.pcp_world_rank]
# [0,1,2] [0,1] | [0,1,0,1] [0,1,0] [0,1,0] [0,0]
# -> [0,1,2] [3,4] | [0,1,0,1] [2,3,2] [4,5,3] [6,4]
_, positions_linear = self._get_cumsum_and_arange(num_padded_scheduled_tokens, arange_np)
positions_linear[self.num_decode_reqs :] = positions_linear[self.num_decode_reqs :] + np.repeat(
num_prefill_tokens_cu_ranks, num_prefill_scheduled_tokens_linear
)
max_scheduled_prefill_tokens = num_prefill_tokens_allranks[:, 0, 0].sum()
num_prefill_tokens = num_scheduled_tokens[self.num_decode_reqs :].sum()
self.total_pcp_padding_tokens_fla = (
max_scheduled_prefill_tokens * self.pcp_world_size - num_prefill_tokens
)
self.pcp_padded_tokens_fla += max_scheduled_prefill_tokens - num_prefill_scheduled_tokens_linear.sum()
max_scheduled_tokens = max_scheduled_prefill_tokens + self.num_decode_tokens
enter_fa_prefill_restore_idx = None
if self.num_reqs - self.num_decode_reqs > 0:
# prefill reorder idx
# [[3,2]] [[2,2,2,1],[2,2,1,1],[1,1,1,1]]
num_prefill_tokens_allranks = num_prefill_tokens_allranks[..., 0]
# [0,1,2,0,1] [0,1,0,1,0,1,0,|0,1,0,1,0,0]
_, prefill_arange_allranks = self._get_cumsum_and_arange(
num_prefill_tokens_allranks.flatten(), arange_np
)
# [0,1] [0,1,2,3,0,1,2,3]
_, prefill_rank_offset = self._get_cumsum_and_arange(
np.ones(self.num_reqs - self.num_decode_reqs, dtype=np.int64) * self.pcp_world_size, arange_np
)
# [0,0,0,3,3] [0,M,2M,3M,0,M,2M,3M] -> [0,0,M,M,2M,2M,3M,0,0,M,M,2M,3M] + D
prefill_all_offset = (
np.repeat(prefill_rank_offset * max_scheduled_tokens, num_prefill_tokens_allranks.flatten())
+ self.num_decode_tokens
)
# [0,0,0,0,|2,2,2,1,|4,4,3,2] -> [0,0,0,0,0,0,0,|2,2,2,2,2,1,|4,4,3,2]
# [[0,0]] -> [0,0,0,0,0]
prefill_local_start_local = np.zeros_like(num_prefill_tokens_allranks)
prefill_local_start_local[1:, :] = np.cumsum(num_prefill_tokens_allranks, axis=0)[:-1, :]
prefill_local_offset = np.repeat(
prefill_local_start_local.flatten(), num_prefill_tokens_allranks.flatten()
)
prefill_all_offset = np.add(prefill_all_offset, prefill_local_offset)
# [0,1,2,3,4] [0,1,M,M+1,2M,2M+1,3M,0,1,M,M+1,2M,3M]
enter_fa_prefill_restore_idx = np.add(prefill_all_offset, prefill_arange_allranks)
else:
_, positions_linear = self._get_cumsum_and_arange(num_padded_scheduled_tokens, arange_np)
# decode reorder idx
enter_fa_decode_restore_idx = None
if self.num_decode_reqs > 0:
# [0,1,2], [4,4,4] -> [0,0,0,0,1,1,1,1,2,2,2,2]
num_decode_pcp_size = np.ones(self.num_decode_reqs, dtype=np.int64) * self.pcp_world_size
decode_reqs_offset = np.repeat(np.arange(self.num_decode_reqs, dtype=np.int64), num_decode_pcp_size)
decode_ranks_offset = (
self._get_cumsum_and_arange(num_decode_pcp_size, arange_np)[1] * max_scheduled_tokens
)
enter_fa_decode_restore_idx = np.add(decode_reqs_offset, decode_ranks_offset)
if enter_fa_decode_restore_idx is not None and enter_fa_prefill_restore_idx is not None:
pcp_enter_fa_restore_idx = torch.from_numpy(
np.concatenate([enter_fa_decode_restore_idx, enter_fa_prefill_restore_idx])
)
elif enter_fa_decode_restore_idx is not None:
pcp_enter_fa_restore_idx = torch.from_numpy(enter_fa_decode_restore_idx)
elif enter_fa_prefill_restore_idx is not None:
pcp_enter_fa_restore_idx = torch.from_numpy(enter_fa_prefill_restore_idx)
self.pcp_enter_fa_restore_idx[: pcp_enter_fa_restore_idx.shape[0]].copy_(
pcp_enter_fa_restore_idx.long(), non_blocking=True
)
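
The offset arithmetic in this block (a rank offset of rank * M, a per-rank cumulative start over earlier requests, plus a local arange) can be checked with toy numbers. A sketch under assumed shapes, not the actual buffers, with no decode requests:

```python
import numpy as np

# pcp_world_size = 2; two prefill requests holding [3, 2] and [2, 1] tokens
# on ranks 0 and 1; each rank is padded to M = 5 slots, and the all-gather
# concatenates rank-major: [r0 slots 0..4, r1 slots 0..4].
M = 5
counts = np.array([[3, 2],   # request 0: tokens held by rank 0, rank 1
                   [2, 1]])  # request 1
start_in_rank = np.zeros_like(counts)
start_in_rank[1:] = np.cumsum(counts, axis=0)[:-1]  # per-rank start offsets

restore = np.concatenate([
    rank * M + start_in_rank[req, rank] + np.arange(counts[req, rank])
    for req in range(2) for rank in range(2)
])
print(restore)  # [0 1 2 5 6 3 4 7] -> request-major order, pad slots dropped
```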
if self.num_reqs > self.num_decode_reqs:
all_positions_prefill = [
get_current_rank_positions(padded_pos_start_loc, rank_i)[self.num_decode_tokens :]
- self.num_decode_tokens * self.pcp_world_size
for rank_i in range(self.pcp_world_size)
]
all_positions_prefill_tensor = torch.from_numpy(np.concatenate(all_positions_prefill))
all_enter_fla_restore_idx = all_positions_prefill_tensor.float().argsort()
unpad_mask_prefill = self.pcp_unpad_mask_cpu[: self.pcp_padded_tokens_length][
self.num_decode_reqs * self.pcp_world_size :
]
# [0] | [0,7]
ori_tokens_start_loc = np.roll(np.cumsum(num_scheduled_tokens[self.num_decode_tokens :]), 1)
ori_tokens_start_loc[0] = 0
# [0,1,2] [3,4] | [0,1,7,8] [2,3,9] [4,5,10] [6,11]
enter_fla_scatter_idx = positions_linear[self.num_decode_reqs :] + np.repeat(
ori_tokens_start_loc, num_prefill_scheduled_tokens_linear
)
enter_fla_restore_idx = torch.index_select(
all_enter_fla_restore_idx[unpad_mask_prefill], 0, torch.from_numpy(enter_fla_scatter_idx)
)
self.pcp_allgather_restore_idx.gpu[: enter_fla_restore_idx.shape[0]].copy_(
enter_fla_restore_idx.long(), non_blocking=True
)
positions_prefill = all_positions_prefill[self.pcp_world_rank]
pcp_fa_query_idx_tensor = torch.from_numpy(positions_prefill)
self.pcp_fa_query_idx[: pcp_fa_query_idx_tensor.shape[0]].copy_(
pcp_fa_query_idx_tensor.long(), non_blocking=True
)
self.pcp_tokens[: self.num_reqs] = pcp_tokens[: self.num_reqs]
self.total_num_sampled_tokens_pcp = num_scheduled_tokens[: self.num_reqs].sum()
self.max_num_tokens_across_pcp = max_scheduled_tokens
self.pcp_tokens_padded = pcp_tokens[: self.num_reqs]
self.num_scheduled_tokens_padded = np.array(self.pcp_tokens_padded, dtype=np.int32)
return num_padded_scheduled_tokens, positions_linear
else:
# Build the restore index used after allgather.
all_positions_lst = [
get_current_rank_positions(padded_pos_start_loc, rank_i) for rank_i in range(self.pcp_world_size)
]
all_positions = np.concatenate(all_positions_lst)
self.pcp_allgather_restore_idx.np[: all_positions.shape[0]] = all_positions.argsort()
self.pcp_allgather_restore_idx.copy_to_gpu(all_positions.shape[0])
self.pcp_tokens[: self.num_reqs] = pcp_tokens[: self.num_reqs]
self.total_num_sampled_tokens_pcp = pcp_tokens[: self.num_reqs].sum()
return pcp_tokens[: self.num_reqs], positions
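
Judging by the worked examples in the comments ([3, 2] for 5 tokens on 2 ranks; [2, 2, 2, 1] and [2, 1, 1, 1] for 7 and 5 tokens on 4 ranks), the per-rank prefill distribution used in the hybrid branch is an as-even-as-possible split, with earlier ranks taking the remainder. A sketch of that rule (an assumption about what _get_cp_local_seq_lens computes; the real helper also takes DCP arguments):

```python
import numpy as np

def even_split(n_tokens: int, world: int) -> np.ndarray:
    # Distribute n_tokens over `world` ranks as evenly as possible,
    # giving the first n_tokens % world ranks one extra token.
    base, rem = divmod(n_tokens, world)
    return base + (np.arange(world) < rem).astype(np.int64)

print(even_split(5, 2))  # [3 2]
print(even_split(7, 4))  # [2 2 2 1]
print(even_split(5, 4))  # [2 1 1 1]
```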
def get_logits_indices(
self,
cu_num_tokens: np.ndarray,
num_reqs: int,
tokens_original: list[int] | None = None,
):
if not self.pcp_use_hybrid_attn or tokens_original is None:
logits_indices = (
torch.from_numpy(cu_num_tokens) * self.pcp_world_size
- self.num_pcp_pads_cpu_tensor[: self.num_reqs]
- 1
)
else:
tokens_original_tensor = torch.tensor(tokens_original, dtype=torch.int32)
num_prefill_reqs = (tokens_original_tensor > self.decode_threshold).sum().item()
num_decode_reqs = num_reqs - num_prefill_reqs
decode_pads = self.pcp_pads_logits_hybrid_attn[:num_decode_reqs]
pad_len = tokens_original_tensor.shape[0] - num_decode_reqs
tokens_logits = tokens_original_tensor + F.pad(decode_pads, (0, pad_len), value=0)
logits_indices = torch.cumsum(tokens_logits, dim=0) - 1
return logits_indices
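
On the non-hybrid path, the last real token of each request in the restored sequence is found by scaling the cumulative per-rank token counts by the world size and subtracting the accumulated pads. A toy check:

```python
import numpy as np
import torch

# pcp_world_size = 2; requests with 5 and 8 real tokens, each padded so the
# two ranks hold 4 + 4 local tokens (toy stand-ins).
cu_num_tokens = np.array([4, 8], dtype=np.int64)  # cumsum of local counts
num_pcp_pads = torch.tensor([3, 3])               # cumulative pads per request

logits_indices = torch.from_numpy(cu_num_tokens) * 2 - num_pcp_pads - 1
print(logits_indices)  # tensor([ 4, 12]) -> last real token of each request
```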
def get_padded_slot_mapping(self, num_tokens: int, num_tokens_padded: int, slot_mapping: torch.Tensor):
# After the PCP all-gather and restore there are padded tokens in the KV,
# so the slot mapping must be padded to match for alignment.
pcp_padded_slot_mapping = self.pcp_padded_slot_mapping[: num_tokens_padded * self.pcp_world_size]
if self.pcp_use_hybrid_attn:
assert self.num_scheduled_tokens_padded is not None
num_tokens = self.num_scheduled_tokens_padded.sum()
pcp_padded_slot_mapping = (
self.pcp_padded_slot_mapping[: num_tokens_padded * self.pcp_world_size]
if not self.pcp_use_hybrid_attn
else self.pcp_padded_slot_mapping[: num_tokens * self.pcp_world_size]
)
cp_unpad_mask = self.pcp_unpad_mask_cpu_tensor[: num_tokens * self.pcp_world_size]
pcp_padded_slot_mapping.fill_(-1)
pcp_padded_slot_mapping[: num_tokens * self.pcp_world_size][cp_unpad_mask] = slot_mapping
return pcp_padded_slot_mapping
if self.pcp_use_hybrid_attn:
return pcp_padded_slot_mapping.clone()
else:
return pcp_padded_slot_mapping
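
In miniature, get_padded_slot_mapping scatters the real slot ids into a -1-filled buffer using the unpad mask, so reshape_and_cache can skip the padding:

```python
import torch

# 6 real tokens scattered into a padded, all-gathered layout of length 8.
unpad_mask = torch.tensor([True, True, False, True, True, True, False, True])
slot_mapping = torch.arange(100, 106)            # one KV slot per real token

padded = torch.full((8,), -1, dtype=torch.long)  # -1 marks padding slots
padded[unpad_mask] = slot_mapping
print(padded)  # tensor([100, 101,  -1, 102, 103, 104,  -1, 105])
```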
def get_restore_hidden_states(
self,
@@ -317,16 +494,25 @@ class PCPManager:
# ignores the padding from CUDA Graph.
from vllm.distributed.parallel_state import get_pcp_group
hidden_states = get_pcp_group().all_gather(
hidden_states[: self.num_actual_tokens_pcp_padded // self.pcp_world_size],
0,
)
restore_idx = self.pcp_allgather_restore_idx.gpu[: hidden_states.shape[0]]
return torch.index_select(
hidden_states,
0,
restore_idx,
)
if not self.pcp_use_hybrid_attn:
hidden_states = get_pcp_group().all_gather(
hidden_states[: self.num_actual_tokens_pcp_padded // self.pcp_world_size],
0,
)
restore_idx = self.pcp_allgather_restore_idx.gpu[: hidden_states.shape[0]]
return torch.index_select(
hidden_states,
0,
restore_idx,
)
else:
if self.pcp_padded_tokens_fla > 0:
hidden_states = F.pad(
hidden_states, pad=(0, 0, 0, self.pcp_padded_tokens_fla), mode="constant", value=0
)
hidden_states = get_pcp_group().all_gather(hidden_states.contiguous(), dim=0)
restore_idx = self.pcp_enter_fa_restore_idx[: hidden_states.shape[0] - self.total_pcp_padding_tokens_fla]
return torch.index_select(hidden_states, 0, restore_idx)
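
In the hybrid branch each rank pads its hidden states up to the batch-wide maximum before the all-gather, and the precomputed restore index both reorders the gathered tokens and drops the padding. A single-process sketch with toy shapes and a fake gather:

```python
import torch
import torch.nn.functional as F

pad_to = 4                                           # max tokens on any rank
h0 = torch.tensor([[0.], [1.], [6.]])                # rank 0: 3 real tokens
h0 = F.pad(h0, pad=(0, 0, 0, pad_to - h0.shape[0]))  # -> 4 rows, 1 zero pad
h1 = torch.tensor([[2.], [3.], [4.], [5.]])          # rank 1: already full
gathered = torch.cat([h0, h1])                       # stand-in for all_gather

restore_idx = torch.tensor([0, 1, 4, 5, 6, 7, 2])    # skips slot 3 (the pad)
restored = torch.index_select(gathered, 0, restore_idx)
print(restored.flatten())  # tensor([0., 1., 2., 3., 4., 5., 6.])
```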
def generate_pcp_mtp_input(
self,
@@ -528,6 +714,15 @@ class PCPManager:
):
from vllm_ascend.attention.utils import AscendPrefillContextParallelMetadata
if self.pcp_world_size > 1 and self.pcp_use_hybrid_attn:
assert self.num_scheduled_tokens_padded is not None
total_num_scheduled_tokens = self.num_scheduled_tokens_padded.sum()
query_lens_new = (
self.query_lens_pcp_full.cpu[:num_reqs]
if self.pcp_world_size > 1 and self.speculative_config
else query_lens
)
num_decodes = (query_lens_new <= self.decode_threshold).sum().item()
num_actual_tokens_pcp_padded = total_num_scheduled_tokens * self.pcp_world_size
self.num_actual_tokens_pcp_padded = num_actual_tokens_pcp_padded
long_seq_metadata = None
@@ -599,10 +794,13 @@ class PCPManager:
if num_reqs_padded > num_reqs:
pad_size = num_reqs_padded - num_reqs
ori_query_lens_cpu[-pad_size:] = torch.full([pad_size], ori_query_lens_cpu[-pad_size - 1].item())
pcp_unpad_mask = self.pcp_unpad_mask_cpu[: self.pcp_padded_tokens_length]
long_seq_metadata = AscendPrefillContextParallelMetadata(
pcp_use_hybrid_attn=self.pcp_use_hybrid_attn,
num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded,
num_computed_tokens_of_pcp_dcp=num_computed_tokens_of_pcp_dcp.numpy(),
pcp_unpad_mask=torch.from_numpy(pcp_unpad_mask),
pcp_padded_tokens_fla=self.pcp_padded_tokens_fla,
)
if ori_query_lens_cpu is not None:
long_seq_metadata.query_lens_pcp_full_cpu = ori_query_lens_cpu
@@ -703,9 +901,20 @@ class PCPManager:
"head_attn_nomask_seqlens": head_attn_nomask_seqlens,
"tail_attn_nomask_seqlens": tail_attn_nomask_seqlens,
}
long_seq_metadata.pcp_allgather_restore_idx = self.pcp_allgather_restore_idx.gpu[
:num_actual_tokens_pcp_padded
]
if not self.pcp_use_hybrid_attn:
long_seq_metadata.pcp_allgather_restore_idx = self.pcp_allgather_restore_idx.gpu[
:num_actual_tokens_pcp_padded
]
else:
long_seq_metadata.pcp_allgather_restore_idx = self.pcp_allgather_restore_idx.gpu[
: num_scheduled_tokens.sum() - num_decodes
]
long_seq_metadata.pcp_fa_query_idx = self.pcp_fa_query_idx[
: num_actual_tokens_pcp_padded // self.pcp_world_size - num_decodes
]
long_seq_metadata.pcp_enter_fa_restore_idx = self.pcp_enter_fa_restore_idx[
: pcp_unpad_mask.sum() + num_decodes * (self.pcp_world_size - 1)
]
long_seq_metadata.q_head_idx_tensor = self.q_head_idx_tensor
long_seq_metadata.q_tail_idx_tensor = self.q_tail_idx_tensor
long_seq_metadata.q_full_idx = self.q_full_idx