[feat] ds3.2 PCP supports MTP and chunked prefill (#6917)
### What this PR does / why we need it?
ds3.2 PCP (prefill context parallelism) now supports being combined with the MTP (multi-token prediction) and chunked prefill features.
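For context, a minimal sketch of the two features being combined; all names here are illustrative stand-ins, not the vllm-ascend API. Chunked prefill splits a long prompt into fixed-size pieces that are prefilled across several scheduler steps, while MTP lets each decode step propose extra draft tokens that are verified together with the regular one.

```python
# Illustrative sketch only; `chunk_prompt` and `decode_step_budget` are
# hypothetical helpers, not part of vLLM or vllm-ascend.

def chunk_prompt(prompt_tokens: list[int], chunk_size: int) -> list[list[int]]:
    """Chunked prefill: split a prompt into fixed-size pieces."""
    return [prompt_tokens[i:i + chunk_size]
            for i in range(0, len(prompt_tokens), chunk_size)]

def decode_step_budget(num_draft_tokens: int) -> int:
    """MTP: each decode step verifies 1 target token plus the draft tokens."""
    return 1 + num_draft_tokens

prompt = list(range(10))
print(chunk_prompt(prompt, 4))   # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
print(decode_step_budget(2))     # 3 positions scheduled per decode step
```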
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.16.0
- vLLM main: 15d76f74e2
---------
Signed-off-by: weiguihua2 <weiguihua2@huawei.com>
```diff
@@ -246,36 +246,7 @@ class NPUModelRunner(GPUModelRunner):
         self.max_num_reqs = self.scheduler_config.max_num_seqs
         self.dp_size = vllm_config.parallel_config.data_parallel_size
         self.dp_rank = vllm_config.parallel_config.data_parallel_rank
-        try:
-            self.dcp_size = get_dcp_group().world_size
-            self.dcp_rank = get_dcp_group().rank_in_group
-            self.pcp_size = get_pcp_group().world_size
-            self.pcp_rank = get_pcp_group().rank_in_group if self.pcp_size > 1 else 0
-        except Exception:
-            self.dcp_size = 1
-            self.dcp_rank = 0
-            self.pcp_size = 1
-            self.pcp_rank = 0
-        if self.pcp_size > 1:
-            self.model_config.max_model_len += 2 * self.pcp_size * self.max_num_reqs
-        max_buffer_num_tokens = self.max_num_tokens
-        if self.pcp_size * self.dcp_size > 1:
-            max_buffer_num_tokens = self.max_num_tokens + self.max_num_reqs * 2 * self.pcp_size
-        self.pcp_manager = PCPManager(
-            self.pcp_size,
-            self.pcp_rank,
-            self.dcp_size,
-            self.dcp_rank,
-            max_buffer_num_tokens,
-            self.max_num_reqs,
-            self.device,
-            self.vllm_config,
-            self.use_async_scheduling,
-            self.pin_memory,
-        )
-        # TODO(zhenwenqi) after https://github.com/vllm-project/vllm/pull/28988 is merged, we can delete this
-        self.input_ids = self._make_buffer(max_buffer_num_tokens, dtype=torch.int32)
-        self.positions = self._make_buffer(max_buffer_num_tokens, dtype=torch.int64)
-
         self.sampler = AscendSampler()
         self.attn_state: AscendAttentionState | None = None

```
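The block removed above (re-added below at its new location) over-allocates the token buffers when context parallelism is active. A worked example of that sizing rule with made-up values; the reading that the `2 * pcp_size` per-request headroom covers alignment padding across PCP ranks is an interpretation, not stated in the diff:

```python
# Made-up values; only the sizing formula is taken from the diff.
max_num_tokens = 8192
max_num_reqs = 16
pcp_size, dcp_size = 2, 1

max_buffer_num_tokens = max_num_tokens
if pcp_size * dcp_size > 1:
    # Reserve up to 2 extra tokens per request per PCP rank.
    max_buffer_num_tokens = max_num_tokens + max_num_reqs * 2 * pcp_size

print(max_buffer_num_tokens)  # 8256 = 8192 + 16 * 2 * 2
```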
```diff
@@ -310,6 +281,38 @@ class NPUModelRunner(GPUModelRunner):
             use_mm_prefix=self.model_config is not None and self.model_config.is_mm_prefix_lm,
         )

+        try:
+            self.dcp_size = get_dcp_group().world_size
+            self.dcp_rank = get_dcp_group().rank_in_group
+            self.pcp_size = get_pcp_group().world_size
+            self.pcp_rank = get_pcp_group().rank_in_group if self.pcp_size > 1 else 0
+        except Exception:
+            self.dcp_size = 1
+            self.dcp_rank = 0
+            self.pcp_size = 1
+            self.pcp_rank = 0
+        if self.pcp_size > 1:
+            self.model_config.max_model_len += 2 * self.pcp_size * self.max_num_reqs
+        max_buffer_num_tokens = self.max_num_tokens
+        if self.pcp_size * self.dcp_size > 1:
+            max_buffer_num_tokens = self.max_num_tokens + self.max_num_reqs * 2 * self.pcp_size
+        self.pcp_manager = PCPManager(
+            self.pcp_size,
+            self.pcp_rank,
+            self.dcp_size,
+            self.dcp_rank,
+            max_buffer_num_tokens,
+            self.max_num_reqs,
+            self.device,
+            self.vllm_config,
+            self.use_async_scheduling,
+            self.pin_memory,
+            self.use_sparse,
+        )
+        # TODO(zhenwenqi) after https://github.com/vllm-project/vllm/pull/28988 is merged, we can delete this
+        self.input_ids = self._make_buffer(max_buffer_num_tokens, dtype=torch.int32)
+        self.positions = self._make_buffer(max_buffer_num_tokens, dtype=torch.int64)
+
         self._set_up_drafter()

         # kv role
```
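The `try`/`except` in the added block is a graceful fallback: if the DCP/PCP process groups were never initialized (for example, in a single-device run), the group accessors raise and the runner behaves as if both world sizes were 1. A standalone sketch of the same pattern; `resolve_cp_layout` and `_missing_group` are hypothetical names for illustration:

```python
# Fallback pattern: default to world size 1 when no process group exists.

def resolve_cp_layout(get_dcp_group, get_pcp_group):
    try:
        dcp_size = get_dcp_group().world_size
        dcp_rank = get_dcp_group().rank_in_group
        pcp_size = get_pcp_group().world_size
        pcp_rank = get_pcp_group().rank_in_group if pcp_size > 1 else 0
    except Exception:
        dcp_size, dcp_rank, pcp_size, pcp_rank = 1, 0, 1, 0
    return dcp_size, dcp_rank, pcp_size, pcp_rank

def _missing_group():
    raise RuntimeError("process group not initialized")

print(resolve_cp_layout(_missing_group, _missing_group))  # (1, 0, 1, 0)
```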