[feat]ds3.2 pcp support mtp and chunkprefill (#6917)
### What this PR does / why we need it?
ds3.2 PCP (prefill context parallel) now supports the combination of the MTP and chunked prefill features.
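For context, a minimal offline-inference sketch of how the combined features might be exercised. The model path is a placeholder and `prefill_context_parallel_size` is only an assumed name for the PCP-degree engine argument; neither is defined by this PR.

```python
from vllm import LLM, SamplingParams

# Sketch only: "prefill_context_parallel_size" is an assumed knob for the PCP
# degree and the model path is a placeholder; neither is defined by this PR.
llm = LLM(
    model="path/to/DeepSeek-V3.2",
    tensor_parallel_size=8,
    prefill_context_parallel_size=2,   # assumed PCP engine argument
    enable_chunked_prefill=True,       # chunked prefill stays enabled
    speculative_config={               # MTP drafting
        "method": "deepseek_mtp",
        "num_speculative_tokens": 1,
    },
)
out = llm.generate(["Hello"], SamplingParams(max_tokens=16))
print(out[0].outputs[0].text)
```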
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.16.0
- vLLM main: 15d76f74e2
---------
Signed-off-by: weiguihua2 <weiguihua2@huawei.com>
@@ -246,36 +246,7 @@ class NPUModelRunner(GPUModelRunner):
        self.max_num_reqs = self.scheduler_config.max_num_seqs
        self.dp_size = vllm_config.parallel_config.data_parallel_size
        self.dp_rank = vllm_config.parallel_config.data_parallel_rank
        try:
            self.dcp_size = get_dcp_group().world_size
            self.dcp_rank = get_dcp_group().rank_in_group
            self.pcp_size = get_pcp_group().world_size
            self.pcp_rank = get_pcp_group().rank_in_group if self.pcp_size > 1 else 0
        except Exception:
            self.dcp_size = 1
            self.dcp_rank = 0
            self.pcp_size = 1
            self.pcp_rank = 0
        if self.pcp_size > 1:
            self.model_config.max_model_len += 2 * self.pcp_size * self.max_num_reqs
        max_buffer_num_tokens = self.max_num_tokens
        if self.pcp_size * self.dcp_size > 1:
            max_buffer_num_tokens = self.max_num_tokens + self.max_num_reqs * 2 * self.pcp_size
        self.pcp_manager = PCPManager(
            self.pcp_size,
            self.pcp_rank,
            self.dcp_size,
            self.dcp_rank,
            max_buffer_num_tokens,
            self.max_num_reqs,
            self.device,
            self.vllm_config,
            self.use_async_scheduling,
            self.pin_memory,
        )
        # TODO(zhenwenqi) after https://github.com/vllm-project/vllm/pull/28988 is merged, we can delete this
        self.input_ids = self._make_buffer(max_buffer_num_tokens, dtype=torch.int32)
        self.positions = self._make_buffer(max_buffer_num_tokens, dtype=torch.int64)

        self.sampler = AscendSampler()
        self.attn_state: AscendAttentionState | None = None
@@ -310,6 +281,38 @@ class NPUModelRunner(GPUModelRunner):
            use_mm_prefix=self.model_config is not None and self.model_config.is_mm_prefix_lm,
        )

        try:
            self.dcp_size = get_dcp_group().world_size
            self.dcp_rank = get_dcp_group().rank_in_group
            self.pcp_size = get_pcp_group().world_size
            self.pcp_rank = get_pcp_group().rank_in_group if self.pcp_size > 1 else 0
        except Exception:
            self.dcp_size = 1
            self.dcp_rank = 0
            self.pcp_size = 1
            self.pcp_rank = 0
        if self.pcp_size > 1:
            self.model_config.max_model_len += 2 * self.pcp_size * self.max_num_reqs
        max_buffer_num_tokens = self.max_num_tokens
        if self.pcp_size * self.dcp_size > 1:
            max_buffer_num_tokens = self.max_num_tokens + self.max_num_reqs * 2 * self.pcp_size
        self.pcp_manager = PCPManager(
            self.pcp_size,
            self.pcp_rank,
            self.dcp_size,
            self.dcp_rank,
            max_buffer_num_tokens,
            self.max_num_reqs,
            self.device,
            self.vllm_config,
            self.use_async_scheduling,
            self.pin_memory,
            self.use_sparse,
        )
        # TODO(zhenwenqi) after https://github.com/vllm-project/vllm/pull/28988 is merged, we can delete this
        self.input_ids = self._make_buffer(max_buffer_num_tokens, dtype=torch.int32)
        self.positions = self._make_buffer(max_buffer_num_tokens, dtype=torch.int64)

        self._set_up_drafter()

        # kv role
@@ -56,6 +56,7 @@ class PCPManager:
        vllm_config: VllmConfig,
        use_async_scheduling: bool,
        pin_memory: bool = False,
        use_sparse: bool = False,
    ) -> None:
        self.pcp_world_size = pcp_world_size
        self.pcp_world_rank = pcp_rank
@@ -97,6 +98,7 @@ class PCPManager:
                + self.pcp_world_size * self.dcp_world_size * self.max_num_reqs
            )
        )
        self.use_sparse = use_sparse
        if self.speculative_config and self.pcp_world_size * self.dcp_world_size > 1:
            self.input_ids_pcp_full = CpuGpuBuffer(
                self.max_num_tokens, dtype=torch.int32, device=device, pin_memory=pin_memory
@@ -784,16 +786,19 @@ class PCPManager:
        num_prefill_reqs = self.num_prefill_reqs
        num_decode_reqs = self.num_decode_reqs
        num_decode_reqs_flatten = ori_query_lens_cpu[:num_decode_reqs].sum().item()
        block_table_tensor[num_decode_reqs_flatten : num_decode_reqs_flatten + num_prefill_reqs].copy_(
            block_table_tensor[num_decode_reqs : num_decode_reqs + num_prefill_reqs].clone()
        )
        block_table_tensor[:num_decode_reqs_flatten].copy_(
            block_table_tensor[:num_decode_reqs].repeat_interleave(ori_query_lens[:num_decode_reqs], dim=0)
        )
        block_table_tensor = block_table_tensor[: num_decode_reqs_flatten + num_prefill_reqs]
        if num_reqs_padded > num_reqs:
            pad_size = num_reqs_padded - num_reqs
            ori_query_lens_cpu[-pad_size:] = torch.full([pad_size], ori_query_lens_cpu[-pad_size - 1].item())
        if not self.use_sparse:
            block_table_tensor[num_decode_reqs_flatten : num_decode_reqs_flatten + num_prefill_reqs].copy_(
                block_table_tensor[num_decode_reqs : num_decode_reqs + num_prefill_reqs].clone()
            )
            block_table_tensor[:num_decode_reqs_flatten].copy_(
                block_table_tensor[:num_decode_reqs].repeat_interleave(ori_query_lens[:num_decode_reqs], dim=0)
            )
            block_table_tensor = block_table_tensor[: num_decode_reqs_flatten + num_prefill_reqs]
        if num_reqs_padded > num_reqs:
            pad_size = num_reqs_padded - num_reqs
            ori_query_lens_cpu[-pad_size:] = torch.full(
                [pad_size], ori_query_lens_cpu[-pad_size - 1].item()
            )
        pcp_unpad_mask = self.pcp_unpad_mask_cpu[: self.pcp_padded_tokens_length]
        long_seq_metadata = AscendPrefillContextParallelMetadata(
            pcp_use_hybrid_attn=self.pcp_use_hybrid_attn,
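Aside from the diff itself: the block-table handling above uses `repeat_interleave` so that every decode token (target token plus MTP drafts) gets its own block-table row. A standalone, illustrative sketch with made-up shapes, not code from this PR:

```python
import torch

# Illustrative only: mirror the repeat_interleave flattening so that every
# decode token gets its own block-table row; shapes are invented.
block_table = torch.tensor([[10, 11],   # decode request 0
                            [20, 21],   # decode request 1
                            [30, 31]])  # prefill request
query_lens = torch.tensor([2, 3])       # tokens per decode request (1 target + MTP drafts)
num_decode = 2

flat_decode = block_table[:num_decode].repeat_interleave(query_lens, dim=0)
flattened = torch.cat([flat_decode, block_table[num_decode:]], dim=0)
print(flattened.shape)  # torch.Size([6, 2]): 5 decode-token rows + 1 prefill row
```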