support cp&dcp (#3260)
### What this PR does / why we need it?
This PR adds the Prefill Context Parallelism (PCP) feature, the prefill-stage counterpart of Decode Context Parallelism (DCP). For the design details, see the RFC: https://github.com/vllm-project/vllm/issues/25749.

TL;DR: PCP improves long-sequence inference by partitioning the sequence dimension across devices during the prefill stage.

### Does this PR introduce _any_ user-facing change?
The current implementation includes the following changes:
- Modified the model runner (model_runner_v1.py) to add the CP partitioning logic for tokens.
- Modified attention_v1.py and mla_v1.py to adapt the GQA/MLA attention backends to PCP.
- Modified block_tables.py to extend the KV-cache storage layout for DCP & PCP.
- Added the command-line arguments needed to control PCP parallelism.

### How was this patch tested?
- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: LookAround <lixushi@huawei.com>
Signed-off-by: chenjie <chenjie137@huawei.com>
Signed-off-by: Delphine-Nic <tanwenqin@huawei.com>
Signed-off-by: zhangsicheng5 <zhangsicheng5@huawei.com>
Signed-off-by: Feng Liu <liufeng248@huawei.com>
Signed-off-by: gaojc <1055866782@qq.com>
Signed-off-by: weiguihua2 <weiguihua2@huawei.com>
Signed-off-by: z50049692 <zhangmingwei11@huawei.com>
Co-authored-by: chenjie <chenjie137@huawei.com>
Co-authored-by: Delphine-Nic <tanwenqin@huawei.com>
Co-authored-by: zhangsicheng5 <zhangsicheng5@huawei.com>
Co-authored-by: Feng Liu <liufeng248@huawei.com>
Co-authored-by: gaojc <1055866782@qq.com>
Co-authored-by: weiguihua2 <weiguihua2@huawei.com>
Co-authored-by: z50049692 <zhangmingwei11@huawei.com>
Co-authored-by: w00896881 <wangzixuan40@huawei.com>
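For context, the sequence split PCP uses for prefill is the standard two-chunks-per-rank causal load balancing: each request's tokens are cut into `2 * pcp_size` chunks, and rank `r` gets chunk `r` (head) plus chunk `2 * pcp_size - 1 - r` (tail). A minimal sketch, illustration only and not code from this patch (`pcp_split` is a hypothetical helper):

```python
# Illustration only: which token indices land on each PCP rank.
import numpy as np

def pcp_split(num_tokens: int, pcp_size: int, rank: int) -> np.ndarray:
    """Return the token indices of `rank`'s head and tail chunks."""
    chunk = num_tokens // (2 * pcp_size)  # assume num_tokens divisible
    head = np.arange(rank * chunk, (rank + 1) * chunk)
    tail_id = 2 * pcp_size - 1 - rank
    tail = np.arange(tail_id * chunk, (tail_id + 1) * chunk)
    return np.concatenate([head, tail])

# With 16 tokens and pcp_size=2: rank 0 gets chunks 0 and 3, rank 1 gets 1 and 2,
# so the cheap early chunk is always paired with an expensive late one.
print(pcp_split(16, 2, 0))  # [ 0  1  2  3 12 13 14 15]
print(pcp_split(16, 2, 1))  # [ 4  5  6  7  8  9 10 11]
```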
@@ -5,6 +5,11 @@ import torch
 from vllm.distributed import get_dcp_group
 from vllm.utils import cdiv
 
+from vllm_ascend.utils import prefill_context_parallel_enable
+
+if prefill_context_parallel_enable():
+    from vllm.distributed import get_pcp_group
+
 
 class BlockTable:
 
@@ -15,7 +20,8 @@ class BlockTable:
                  max_num_batched_tokens: int,
                  pin_memory: bool,
                  device: torch.device,
-                 kernel_sizes: Union[list[int], None] = None):
+                 kernel_sizes: Union[list[int], None] = None,
+                 cp_kv_cache_interleave_size: int = 1):
         self.max_num_reqs = max_num_reqs
         self.max_num_blocks_per_req = max_num_blocks_per_req
         self.max_num_batched_tokens = max_num_batched_tokens
@@ -80,13 +86,20 @@ class BlockTable:
             dtype=torch.int64,
             device=self.device)
         try:
+            self.pcp_world_size = get_pcp_group(
+            ).world_size if prefill_context_parallel_enable() else 1
+            self.pcp_rank = get_pcp_group(
+            ).rank_in_group if self.pcp_world_size > 1 else 0
             self.dcp_world_size = get_dcp_group().world_size
             self.dcp_rank = get_dcp_group().rank_in_group
         except AssertionError:
             # DCP might not be initialized in testing
             self.dcp_world_size = 1
             self.dcp_rank = 0
+            self.pcp_world_size = 1
+            self.pcp_rank = 0
         self.kernel_sizes = kernel_sizes
+        self.cp_kv_cache_interleave_size = cp_kv_cache_interleave_size
 
     def append_row(
         self,
@@ -132,14 +145,14 @@ class BlockTable:
         # here because M (max_model_len) is not necessarily divisible by
         # block_size.
 
-        if self.dcp_world_size > 1:
+        if self.dcp_world_size * self.pcp_world_size > 1:
            # Note(hc): The DCP implement store kvcache with an interleave
            # style, the kvcache for the token whose token_idx is i is
            # always stored on the GPU whose dcp_rank equals i % cp_world_size:
 
            # Use a "virtual block" which equals to world_size * block_size
            # for block_table_indices calculation.
-           virtual_block_size = self.block_size * self.dcp_world_size
+           virtual_block_size = self.block_size * self.dcp_world_size * self.pcp_world_size
 
            # IMPORTANT: In hybrid mode, positions are in logical block space,
            # but we need to map them to the correct logical block table indices
@@ -157,9 +170,14 @@ class BlockTable:
            # Use virtual_block_size for mask calculation, which marks local
            # tokens.
            virtual_block_offsets = positions % virtual_block_size
-           mask = virtual_block_offsets % self.dcp_world_size == self.dcp_rank
+           self.current_rank = self.dcp_world_size * self.pcp_rank + self.dcp_rank
+           mask = (virtual_block_offsets // self.cp_kv_cache_interleave_size %
+                   (self.dcp_world_size *
+                    self.pcp_world_size) == self.current_rank)
            # Calculate local block_offsets
-           block_offsets = virtual_block_offsets // self.dcp_world_size
+           block_offsets = virtual_block_offsets \
+               // (self.dcp_world_size * self.pcp_world_size * self.cp_kv_cache_interleave_size) \
+               * self.cp_kv_cache_interleave_size + virtual_block_offsets % self.cp_kv_cache_interleave_size
            # Calculate slot_mapping
            slot_mapping = block_numbers * self.block_size + block_offsets
            # Write final slots, use -1 for not-local
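The hunk above generalizes DCP's interleaved ownership rule to the combined PCP x DCP world. A standalone sketch of the same arithmetic, illustration only and not code from this patch (`owner_and_local_offset` is a hypothetical helper):

```python
# Illustration only: which CP rank owns token i under the interleaved layout,
# and the token's local offset inside that rank's KV cache.
# world_size = dcp_world_size * pcp_world_size; interleave is
# cp_kv_cache_interleave_size.
def owner_and_local_offset(i: int, world_size: int, interleave: int):
    owner = (i // interleave) % world_size
    local = (i // (world_size * interleave)) * interleave + i % interleave
    return owner, local

# interleave=1 reduces to the original DCP rule: owner = i % world_size.
assert owner_and_local_offset(5, 4, 1) == (1, 1)
# interleave=2, world_size=2: tokens 0,1 -> rank 0; 2,3 -> rank 1; 4,5 -> rank 0.
assert [owner_and_local_offset(i, 2, 2)[0] for i in range(6)] == [0, 0, 1, 1, 0, 0]
```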
@@ -242,16 +260,20 @@ class MultiGroupBlockTable:
                  device: torch.device,
                  block_sizes: list[int],
                  num_speculative_tokens: int = 0,
-                 kernel_sizes: Optional[list[list[int]]] = None) -> None:
+                 kernel_sizes: Optional[list[list[int]]] = None,
+                 cp_kv_cache_interleave_size: int = 1) -> None:
         # Note(hc): each dcp rank only store
         # (max_model_len//dcp_world_size) tokens in kvcache,
         # so the block_size which used for calc max_num_blocks_per_req
         # must be multiplied by dcp_world_size.
         try:
             dcp_world_size = get_dcp_group().world_size
+            cp_world_size = get_pcp_group(
+            ).world_size if prefill_context_parallel_enable() else 1
         except AssertionError:
             # DCP might not be initialized in testing
             dcp_world_size = 1
+            cp_world_size = 1
 
         if kernel_sizes is None:
             kernel_sizes = [[0]] * len(block_sizes)
@@ -267,9 +289,12 @@ class MultiGroupBlockTable:
         self.block_tables = [
             BlockTable(
                 block_size, max_num_reqs,
-                max(cdiv(max_model_len, block_size * dcp_world_size),
-                    1 + num_speculative_tokens), max_num_batched_tokens,
-                pin_memory, device, kernel_size_list)
+                max(
+                    cdiv(max_model_len,
+                         block_size * dcp_world_size * cp_world_size),
+                    1 + num_speculative_tokens), max_num_batched_tokens,
+                pin_memory, device, kernel_size_list,
+                cp_kv_cache_interleave_size)
             for block_size, kernel_size_list in zip(block_sizes, kernel_sizes)
         ]
 
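A quick numeric illustration of the `max_num_blocks_per_req` change above (not part of the diff): since each rank stores only `max_model_len / (dcp_world_size * pcp_world_size)` tokens, the per-request block table shrinks accordingly.

```python
# Illustration only, with made-up sizes.
from math import ceil

def cdiv(a: int, b: int) -> int:
    return ceil(a / b)

max_model_len, block_size, dcp, pcp = 32768, 128, 2, 2
print(cdiv(max_model_len, block_size))              # 256 blocks without CP
print(cdiv(max_model_len, block_size * dcp * pcp))  # 64 blocks per rank with CP
```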
@@ -50,8 +50,8 @@ from vllm.distributed import tensor_model_parallel_all_gather
 from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                           has_kv_transfer_group)
 from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
-from vllm.distributed.parallel_state import (get_dp_group, get_pp_group,
-                                             get_tp_group,
+from vllm.distributed.parallel_state import (get_dcp_group, get_dp_group,
+                                             get_pp_group, get_tp_group,
                                              is_global_first_rank)
 from vllm.forward_context import BatchDescriptor, get_forward_context
 from vllm.logger import logger
@@ -107,7 +107,8 @@ from vllm_ascend.ascend_forward_context import (MoECommType,
                                                 set_ascend_forward_context)
 from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
-from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
+from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
+                                         AscendPrefillContextParallelMetadata)
 from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper,
                                                set_graph_params,
                                                update_attn_params,
@@ -132,9 +133,16 @@ from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
                                AscendSocVersion, ProfileExecuteDuration,
                                enable_sp, get_ascend_soc_version, is_310p,
-                               is_enable_nz, lmhead_tp_enable)
+                               is_enable_nz, lmhead_tp_enable,
+                               prefill_context_parallel_enable)
 from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch
 
+if prefill_context_parallel_enable():
+    from vllm.distributed import get_pcp_group
+    from vllm.distributed.parallel_state import (
+        get_prefill_context_model_parallel_rank,
+        get_prefill_context_model_parallel_world_size)
+
 if TYPE_CHECKING:
     import xgrammar as xgr  # type: ignore[import-untyped]
     from vllm.v1.core.sched.output import SchedulerOutput
@@ -260,6 +268,12 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                                          decode_max_num_seqs)
         self.dp_size = vllm_config.parallel_config.data_parallel_size
         self.dp_rank = vllm_config.parallel_config.data_parallel_rank
+        self.pcp_size = get_prefill_context_model_parallel_world_size(
+        ) if prefill_context_parallel_enable() else 1
+        self.pcp_rank = get_prefill_context_model_parallel_rank(
+        ) if self.pcp_size > 1 else 0
+        self.dcp_size = get_dcp_group().world_size
+        self.dcp_rank = get_dcp_group().rank_in_group
         self.device = device
         if envs_ascend.VLLM_ASCEND_ENABLE_PREFETCH_MLP:
             self.prefetch_stream = torch.npu.Stream(device=device)
@@ -320,7 +334,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             self.block_size,
             use_mla=self.model_config.use_mla,
             use_sparse=self.use_sparse)
-        if torch.version.cann.startswith("8.3"):
+        if self.pcp_size > 1:
+            self.attn_mask_builder = None
+        elif torch.version.cann.startswith("8.3"):
             self.attn_mask_builder = AttentionMaskBuilder(
                 self.scheduler_config.max_num_batched_tokens, self.dtype,
                 self.device)
@@ -454,6 +470,13 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                                           device="cpu",
                                           pin_memory=True)
         self.seq_lens_np = self.seq_lens_cpu.numpy()
+        self.pcp_allgather_restore_idx = torch.zeros(self.max_num_tokens,
+                                                     dtype=torch.int32,
+                                                     device=self.device)
+        self.num_pcp_pads = torch.zeros(self.max_num_reqs, dtype=torch.int32)
+        self.pcp_padded_slot_mapping = torch.zeros(self.max_num_tokens,
+                                                   dtype=torch.int32,
+                                                   device=self.device)
 
         self.use_aclgraph = self._use_aclgraph()
         self.aclgraph_batch_sizes = list(
@@ -525,6 +548,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 self.vllm_config.model_config.logits_processors),
             is_pooling_model=self.is_pooling_model,
             kernel_block_sizes=[[self.vllm_config.cache_config.block_size]],
+            cp_kv_cache_interleave_size=self.parallel_config.
+            cp_kv_cache_interleave_size
+            if prefill_context_parallel_enable() else 1,
         )
         self.num_accepted_tokens = self._make_buffer(self.max_num_reqs,
                                                      dtype=torch.int64)
@@ -890,12 +916,20 @@ class NPUModelRunner(LoRAModelRunnerMixin):
 
     def _make_attention_mask(self, seq_lens, position,
                              attn_state) -> torch.Tensor:
+        if self.pcp_size > 1:
+            return None
+        if self.attn_mask_builder is None:
+            raise ValueError("Attn mask builder is None")
         # Pooling situation.
         if self.model_config.runner_type == "pooling" and self.model_config.pooler_config.pooling_type == "CLS":
             return self.attn_mask_builder.get_pooling_mask(self.device)
         # Chunk Prefill situation.
         elif attn_state == AscendAttentionState.ChunkedPrefill and not self.vllm_config.model_config.use_mla and not self.use_sparse:
-            if torch.version.cann.startswith("8.3"):
+            if self.dcp_size > 1:
+                max_seq_len = max(seq_lens.max().item(), 0)
+                return self.attn_mask_builder.get_attn_mask(
+                    max_seq_len, self.dtype, self.device)
+            elif torch.version.cann.startswith("8.3"):
                 return self.attn_mask_builder.get_splitfuse_attn_mask()
             else:
                 return self.attn_mask_builder.get_splitfuse_attn_mask(
@@ -945,7 +979,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 src_end = num_computed_tokens + prompt_part_len
 
                 self.mrope_positions_cpu[:, dst_start:dst_end] = \
-                    req.mrope_positions[:,src_start:src_end]
+                    req.mrope_positions[:, src_start:src_end]
 
                 mrope_pos_ptr += prompt_part_len
 
@@ -1219,7 +1253,27 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         req_ids = self.input_batch.req_ids
         tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
         num_scheduled_tokens = np.array(tokens, dtype=np.int32)
-        max_num_scheduled_tokens = num_scheduled_tokens.max()
+
+        req_indices = np.repeat(self.arange_np[:num_reqs],
+                                num_scheduled_tokens)
+        _, arange = self._get_cumsum_and_arange(num_scheduled_tokens)
+        positions_np = np.add(
+            self.input_batch.num_computed_tokens_cpu[req_indices],
+            arange,
+        )
+
+        self.input_batch.block_table.compute_slot_mapping(
+            req_indices, positions_np)
+        tokens, position_pcp, pcp_unpad_mask = self._update_tokens_for_pcp(
+            tokens)
+        num_scheduled_tokens = np.array(tokens, dtype=np.int32)
+        # update total_num_scheduled_tokens
+        total_num_scheduled_tokens = sum(num_scheduled_tokens[:num_reqs])
+        self.input_batch.block_table.commit_slot_mapping(
+            total_num_scheduled_tokens)
+
+        total_num_pcp_pads = sum(self.num_pcp_pads)
+        max_num_scheduled_tokens = max(tokens)
         num_valid_tokens = np.array([
             num_tokens -
             len(scheduler_output.scheduled_spec_decode_tokens.get(i, []))
@@ -1284,10 +1338,13 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         cu_num_tokens, arange = self._get_cumsum_and_arange(
             num_scheduled_tokens)
 
-        positions_np = self.positions_np[:total_num_scheduled_tokens]
-        np.add(self.input_batch.num_computed_tokens_cpu[req_indices],
-               arange,
-               out=positions_np)
+        if self.pcp_size > 1:
+            positions_np = self.positions_np[:total_num_scheduled_tokens]
+            np.add(self.input_batch.num_computed_tokens_cpu[req_indices],
+                   position_pcp[:total_num_scheduled_tokens],
+                   out=positions_np)
+        else:
+            self.positions_np[:total_num_scheduled_tokens] = positions_np
 
         # Calculate M-RoPE positions.
         # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
@@ -1315,13 +1372,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                      torch.from_numpy(token_indices),
                      out=self.input_ids_cpu[:total_num_scheduled_tokens])
 
-        # Prepare some information for building Attention-Metadata
-        # Compute and commit slot mapping
-        self.input_batch.block_table.compute_slot_mapping(
-            req_indices, positions_np)
-        self.input_batch.block_table.commit_slot_mapping(
-            total_num_scheduled_tokens)
-
         self.query_start_loc_np[0] = 0
         self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens
         self.query_start_loc[:num_reqs + 1].copy_(
@@ -1351,6 +1401,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         positions_cpu = self.positions_cpu[:num_input_tokens]
         positions = self.positions[:num_input_tokens]
+        seq_lens_cpu = self.seq_lens_cpu[:num_reqs]
 
         attn_state = self._build_attn_state(num_reqs, num_scheduled_tokens,
                                             num_valid_tokens)
         self.attn_mask = self._make_attention_mask(seq_lens=seq_lens_cpu,
@@ -1428,9 +1479,13 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             # We will ignore the sampled tokens from the partial requests.
             # TODO: Support prompt logprobs.
             spec_decode_metadata = None
-            logits_indices = torch.from_numpy(cu_num_tokens - 1).to(
-                self.device, non_blocking=True)
+            logits_indices = torch.from_numpy(
+                cu_num_tokens
+            ) * self.pcp_size - self.num_pcp_pads[:num_reqs] - 1
+            logits_indices = logits_indices.to(self.device, non_blocking=True)
         else:
+            # pcp not supported now
+            assert self.pcp_size == 1
             # Get the number of draft tokens for each request.
             # Iterate over the dictionary rather than all requests since not all
             # requests have draft tokens.
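A small numeric trace of the `logits_indices` change above (illustration only, made-up sizes): each rank holds `1/pcp_size` of every request's padded tokens, so after the all-gather a request spans `len * pcp_size` rows, of which `num_pcp_pads` at the tail are padding; the last real token sits just before them.

```python
# Illustration only: index of each request's last valid token after the
# PCP all-gather.
import numpy as np

pcp_size = 2
per_rank_tokens = np.array([4, 2])          # tokens each rank holds per request
cu_num_tokens = np.cumsum(per_rank_tokens)  # [4, 6]
num_pcp_pads = np.array([1, 3])             # padding added across all ranks
logits_indices = cu_num_tokens * pcp_size - num_pcp_pads - 1
print(logits_indices)  # [6 8] -> last real token of request 0 and request 1
```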
@@ -1458,10 +1513,17 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             self.num_accepted_tokens.np[num_reqs:].fill(1)
             self.num_accepted_tokens.copy_to_gpu()
 
+        # prepare pcp meta data
+        long_seq_metadata = self._generate_pcp_metadata(
+            total_num_scheduled_tokens, seq_lens_cpu)
         # Prepare the attention metadata for each KV cache group and make layers
         # in the same group share the same metadata.
         for kv_cache_group_id, kv_cache_group_spec in enumerate(
                 self.kv_cache_config.kv_cache_groups):
+            slot_mapping_size = (total_num_scheduled_tokens
+                                 if self.pcp_size == 1 else
+                                 total_num_scheduled_tokens * self.pcp_size -
+                                 total_num_pcp_pads)
             if isinstance(kv_cache_group_spec.kv_cache_spec,
                           EncoderOnlyAttentionSpec):
                 # Encoder-only layers do not have KV cache, so we need to
@@ -1479,13 +1541,24 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             else:
                 blk_table = self.input_batch.block_table[kv_cache_group_id]
                 blk_table_tensor = blk_table.get_device_tensor()
-                slot_mapping = blk_table.slot_mapping_cpu[:
-                                                          total_num_scheduled_tokens]
-                self.slot_mapping[:total_num_scheduled_tokens].copy_(
-                    slot_mapping[:total_num_scheduled_tokens],
+                slot_mapping = blk_table.slot_mapping_cpu[:slot_mapping_size]
+                self.slot_mapping[:slot_mapping_size].copy_(
+                    slot_mapping[:slot_mapping_size],
                     non_blocking=True,
                 )
-                self.slot_mapping[total_num_scheduled_tokens:].fill_(0)
+                self.slot_mapping[slot_mapping_size:].fill_(0)
+                if self.pcp_size > 1:
+                    assert pcp_unpad_mask is not None
+                    pcp_padded_slot_mapping = \
+                        self.pcp_padded_slot_mapping[:pcp_unpad_mask.shape[0]]
+                    pcp_padded_slot_mapping.fill_(-1)
+                    pcp_padded_slot_mapping[
+                        pcp_unpad_mask] = self.slot_mapping[:slot_mapping_size]
+                    self.slot_mapping[:long_seq_metadata.
+                                      num_actual_tokens_pcp_padded] = pcp_padded_slot_mapping
 
             # Make AscendCommonAttentionMetadata
             common_attn_metadata = AscendCommonAttentionMetadata(
@@ -1494,7 +1567,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 seq_lens_cpu=self.seq_lens_cpu,
                 seq_lens=self.seq_lens_cpu[:num_reqs],
                 num_reqs=num_reqs,
-                num_actual_tokens=total_num_scheduled_tokens,
+                num_actual_tokens=slot_mapping_size,
                 num_input_tokens=num_input_tokens,
                 actual_seq_lengths_q=self.actual_seq_lengths_q,
                 # TODO: change this to the right block table for linear attn
@@ -1512,6 +1585,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 decode_token_per_req=self.decode_token_per_req,
                 cos=self.cos,
                 sin=self.sin,
+                prefill_context_parallel_metadata=long_seq_metadata,
             )
 
             if self.speculative_config and \
@@ -1587,6 +1661,12 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             pad_size = get_forward_context().pad_size
             if pad_size > 0:
                 hidden_states = hidden_states[:-pad_size, :]
+
+        if self.pcp_size > 1:
+            hidden_states = get_pcp_group().all_gather(hidden_states, 0)
+            hidden_states = torch.index_select(
+                hidden_states, 0,
+                self.pcp_allgather_restore_idx[:hidden_states.shape[0]])
         return hidden_states
 
     def _build_attn_state(self, num_reqs, num_scheduled_tokens,
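The `all_gather` above returns hidden states in rank-major order, so `pcp_allgather_restore_idx` (the argsort of every rank's global token positions) puts them back in sequence order. A self-contained sketch, illustration only and not code from this patch:

```python
# Illustration only: undoing rank-major all_gather ordering via argsort.
import torch

# Global token positions owned by each rank (head+tail chunks), rank-major:
positions = torch.tensor([0, 1, 6, 7,   # rank 0
                          2, 3, 4, 5])  # rank 1
restore_idx = positions.float().argsort().long()
gathered = torch.arange(8)  # stands in for the all-gathered hidden-state rows
restored = torch.index_select(gathered, 0, restore_idx)
# restored[i] is the gathered row holding global token i:
print(restored)  # tensor([0, 1, 4, 5, 6, 7, 2, 3])
```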
@@ -2485,8 +2565,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
     def profile_run(self) -> None:
         # Trigger compilation for general shape.
         with self.set_in_profile_run():
-            hidden_states = self._dummy_run(self.max_num_tokens,
-                                            with_prefill=True)
+            hidden_states = self._dummy_run(
+                self.max_num_tokens //
+                self.pcp_size if self.pcp_size > 1 else self.max_num_tokens,
+                with_prefill=True)
         # MC2 will consume additional NPU memory.
         # Therefore, we need to run the MC2 path once here to complete its initialization,
         # allowing vLLM to correctly estimate the maximum memory required.
@@ -3620,3 +3702,236 @@ class NPUModelRunner(LoRAModelRunnerMixin):
 
     def _build_drafter_prepare_inputs_torchair_param(self):
         return False
+
+    def _update_tokens_for_pcp(self, tokens):
+        num_reqs = self.input_batch.num_reqs
+        self.num_pcp_pads = self.num_pcp_pads[:num_reqs]
+        if not self.pcp_size > 1:
+            return tokens, None, None
+        tokens = np.array(tokens, dtype=np.int32)
+        num_decode_reqs = sum(
+            self.input_batch.num_computed_tokens_cpu[:num_reqs] >=
+            self.input_batch.num_prompt_tokens[:num_reqs])
+        num_padded_scheduled_tokens = np.ceil(
+            tokens /
+            (2 * self.pcp_size)).astype(np.int32) * (2 * self.pcp_size)
+        num_padded_scheduled_tokens[:num_decode_reqs] = self.pcp_size
+        self.num_pcp_pads = num_padded_scheduled_tokens - tokens
+        cu_padded_tokens, pcp_padded_arange = \
+            self._get_cumsum_and_arange(num_padded_scheduled_tokens)
+        unpad_mask = torch.from_numpy(
+            pcp_padded_arange < np.repeat(tokens, num_padded_scheduled_tokens))
+
+        pcp_tokens = num_padded_scheduled_tokens // self.pcp_size
+        pcp_chunk_sizes = (pcp_tokens // 2).clip(min=1)
+        _, pcp_arange = self._get_cumsum_and_arange(pcp_tokens)
+        _, pcp_chunk_arange = self._get_cumsum_and_arange(pcp_chunk_sizes)
+        pcp_head_chunk_mask = pcp_arange < np.repeat(pcp_chunk_sizes,
+                                                     pcp_tokens)
+
+        def get_current_rank_positions(cu_tokens, rank):
+            positions_start_loc = np.zeros_like(cu_tokens)
+            positions_start_loc[1:] = cu_tokens[:-1]
+            positions = np.zeros(len(pcp_head_chunk_mask), dtype=np.int32)
+            head_start_loc = positions_start_loc + rank * pcp_chunk_sizes
+            tail_start_loc = positions_start_loc + \
+                (2 * self.pcp_size - rank - 1) * pcp_chunk_sizes
+            positions[pcp_head_chunk_mask] = pcp_chunk_arange + \
+                np.repeat(head_start_loc, pcp_chunk_sizes)
+            # Decode reqs do not have tail chunks.
+            positions[~pcp_head_chunk_mask] = \
+                pcp_chunk_arange[num_decode_reqs:] + \
+                np.repeat(tail_start_loc, pcp_chunk_sizes)[num_decode_reqs:]
+            return positions
+
+        positions = get_current_rank_positions(
+            np.zeros(num_reqs, dtype=np.int32), self.pcp_rank)
+        # Decode tokens are duplicate and their positions always be 0.
+        positions[:num_decode_reqs] = 0
+
+        all_positions = [
+            get_current_rank_positions(cu_padded_tokens, rank_i)
+            for rank_i in range(self.pcp_size)
+        ]
+        all_positions_tensor = torch.from_numpy(np.concatenate(all_positions))
+        self.pcp_allgather_restore_idx[:all_positions_tensor.shape[0]].copy_(
+            all_positions_tensor.float().argsort().long(), non_blocking=True)
+        pcp_tokens[:num_decode_reqs] = 1
+        return pcp_tokens, positions, unpad_mask
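A quick trace of the padding rule in `_update_tokens_for_pcp` (illustration only, made-up sizes): prefill lengths are rounded up to a multiple of `2 * pcp_size` so every rank gets an equal head chunk and tail chunk, while decode requests are duplicated with one token per rank.

```python
# Illustration only: padding scheduled prefill tokens for PCP.
import numpy as np

pcp_size = 2
tokens = np.array([7, 10])  # scheduled prefill tokens of two requests
padded = np.ceil(tokens / (2 * pcp_size)).astype(np.int32) * (2 * pcp_size)
print(padded)               # [ 8 12]
print(padded - tokens)      # [1 2]  -> num_pcp_pads
print(padded // pcp_size)   # [4 6]  -> tokens actually run on each rank
```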
+    def _get_pcp_local_seq_lens(
+        self,
+        seq_lens: torch.Tensor,
+        pcp_world_size: int = 1,
+        dcp_world_size: int = 1,
+        cp_kv_cache_interleave_size: int = 1,
+    ) -> torch.Tensor:
+        """While using pcp or dcp, kv_cache size stored on each rank may be different,
+        use this function to calculate split decode seq_lens of each (p/d)cp rank.
+        """
+        num_requests = seq_lens.size(0)
+        total_world_size = pcp_world_size * dcp_world_size
+        seq_lens_tiled = seq_lens.unsqueeze(-1).repeat(1, total_world_size)
+        rank_offsets = (torch.arange(total_world_size,
+                                     dtype=torch.int32).unsqueeze(0).repeat(
+                                         num_requests, 1))
+        base = (seq_lens_tiled // cp_kv_cache_interleave_size //
+                total_world_size * cp_kv_cache_interleave_size)
+        remainder = seq_lens_tiled - base * total_world_size
+        remainder = torch.clip(
+            remainder - rank_offsets * cp_kv_cache_interleave_size,
+            0,
+            cp_kv_cache_interleave_size,
+        )
+        dcp_local_seq_lens = (base + remainder).reshape(
+            [-1, pcp_world_size, dcp_world_size])
+        return dcp_local_seq_lens
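A worked example of `_get_pcp_local_seq_lens` (illustration only, made-up sizes): how many KV entries of a length-10 request land on each of 4 CP ranks when `cp_kv_cache_interleave_size=2`, i.e. tokens are dealt out round-robin in groups of 2.

```python
# Illustration only: per-rank KV counts under the interleaved split.
import torch

seq_lens = torch.tensor([10])
total_ws, interleave = 4, 2
base = seq_lens // interleave // total_ws * interleave   # tensor([2])
remainder = seq_lens - base * total_ws                   # tensor([2])
ranks = torch.arange(total_ws)
per_rank = base + torch.clip(remainder - ranks * interleave, 0, interleave)
print(per_rank)  # tensor([4, 2, 2, 2]); groups 0-1 and 8-9 both land on rank 0
```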
+    def _generate_pcp_metadata(self, total_num_scheduled_tokens, seq_lens):
+        num_reqs = self.input_batch.num_reqs
+        num_decodes = sum(self.input_batch.num_computed_tokens_cpu[:num_reqs]
+                          >= self.input_batch.num_prompt_tokens[:num_reqs])
+        num_actual_tokens_pcp_padded = total_num_scheduled_tokens * self.pcp_size
+        num_prefills = num_reqs - num_decodes
+        long_seq_metadata = None
+        if self.pcp_size * self.dcp_size > 1:
+            long_seq_metadata = AscendPrefillContextParallelMetadata(
+                num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded,
+                num_computed_tokens_of_pcp_dcp=self._get_pcp_local_seq_lens(
+                    seq_lens,
+                    self.pcp_size,
+                    self.dcp_size,
+                    self.parallel_config.cp_kv_cache_interleave_size,
+                ).numpy(),
+            )
+        if self.pcp_size > 1:
+            q_head_idx, q_tail_idx = [], []
+            kv_with_q_head_nomask_idx, kv_with_q_head_mask_idx = [], []
+            kv_with_q_tail_nomask_idx, kv_with_q_tail_mask_idx = [], []
+            chunk_seqlens = []
+            kv_with_q_head_nomask_seqlens, kv_with_q_tail_nomask_seqlens = [], []
+            q_req_offset = 0
+            kv_req_offset = 0
+            q_head_chunk_id = self.pcp_rank
+            q_tail_chunk_id = self.pcp_size * 2 - 1 - self.pcp_rank
+            for i, seq_len in enumerate(seq_lens):
+                if i < num_decodes:
+                    continue
+                chunk_len = seq_len // 2
+                chunk_seqlens.append(chunk_len)
+                q_head_idx.extend(
+                    list(range(q_req_offset, q_req_offset + chunk_len)))
+                kv_with_q_head_nomask_idx.extend(
+                    list(
+                        range(kv_req_offset, kv_req_offset +
+                              chunk_len * q_head_chunk_id)))
+                kv_with_q_head_mask_idx.extend(
+                    list(
+                        range(
+                            kv_req_offset + chunk_len * q_head_chunk_id,
+                            kv_req_offset + chunk_len *
+                            (q_head_chunk_id + 1))))
+                kv_with_q_head_nomask_seqlens.append(chunk_len *
+                                                     q_head_chunk_id)
+
+                q_tail_idx.extend(
+                    list(
+                        range(q_req_offset + chunk_len,
+                              q_req_offset + chunk_len * 2)))
+                kv_with_q_tail_nomask_idx.extend(
+                    list(
+                        range(kv_req_offset, kv_req_offset +
+                              chunk_len * q_tail_chunk_id)))
+                kv_with_q_tail_mask_idx.extend(
+                    list(
+                        range(
+                            kv_req_offset + chunk_len * q_tail_chunk_id,
+                            kv_req_offset + chunk_len *
+                            (q_tail_chunk_id + 1))))
+                kv_with_q_tail_nomask_seqlens.append(chunk_len *
+                                                     q_tail_chunk_id)
+
+                q_req_offset += seq_len
+                kv_req_offset += seq_len * self.pcp_size
+
+            # Convert lists to tensors and move to device
+            def _list_to_tensor(lst, device, dtype=torch.int32):
+                tensor_npu = torch.zeros(len(lst), dtype=dtype, device=device)
+                tensor_npu.copy_(torch.tensor(lst, dtype=dtype),
+                                 non_blocking=True)
+                return tensor_npu
+
+            q_head_idx_tensor = _list_to_tensor(q_head_idx, self.device)
+            q_tail_idx_tensor = _list_to_tensor(q_tail_idx, self.device)
+            self.q_head_idx_tensor = q_head_idx_tensor
+            self.q_tail_idx_tensor = q_tail_idx_tensor
+
+            q_full_idx = torch.cat([q_head_idx_tensor, q_tail_idx_tensor])
+            q_full_idx = q_full_idx.to(torch.float32).argsort().to(
+                torch.int32)
+            self.q_full_idx = q_full_idx
+
+            self.kv_idx_names = {
+                'kv_with_q_head_nomask_idx_tensor': kv_with_q_head_nomask_idx,
+                'kv_with_q_head_mask_idx_tensor': kv_with_q_head_mask_idx,
+                'kv_with_q_tail_nomask_idx_tensor': kv_with_q_tail_nomask_idx,
+                'kv_with_q_tail_mask_idx_tensor': kv_with_q_tail_mask_idx
+            }
+            for key, value in self.kv_idx_names.items():
+                tensor_npu = _list_to_tensor(value, self.device)
+                self.kv_idx_names[key] = tensor_npu
+
+            attn_mask_seqlens = torch.tensor(
+                [chunk_seqlens, chunk_seqlens], dtype=torch.int32)
+            head_attn_nomask_seqlens = torch.tensor(
+                [chunk_seqlens, kv_with_q_head_nomask_seqlens],
+                dtype=torch.int32)
+            tail_attn_nomask_seqlens = torch.tensor(
+                [chunk_seqlens, kv_with_q_tail_nomask_seqlens],
+                dtype=torch.int32)
+            if self.vllm_config.model_config.use_mla:
+                pcp_prefill_mask = torch.triu(
+                    torch.ones(512,
+                               512,
+                               device=self.device,
+                               dtype=self.dtype), 1)
+            else:
+                max_seq_len = max(seq_lens, default=0)
+                pcp_prefill_mask = torch.triu(
+                    torch.full((num_prefills, max_seq_len, max_seq_len),
+                               True,
+                               device=self.device,
+                               dtype=torch.bool), 1)
+
+            self.extra_long_seq_kwargs = {
+                'attn_mask_seqlens': attn_mask_seqlens,
+                'head_attn_nomask_seqlens': head_attn_nomask_seqlens,
+                'tail_attn_nomask_seqlens': tail_attn_nomask_seqlens,
+                'pcp_prefill_mask': pcp_prefill_mask
+            }
+            long_seq_metadata.pcp_allgather_restore_idx = \
+                self.pcp_allgather_restore_idx[:num_actual_tokens_pcp_padded]
+            long_seq_metadata.q_head_idx_tensor = self.q_head_idx_tensor
+            long_seq_metadata.q_tail_idx_tensor = self.q_tail_idx_tensor
+            long_seq_metadata.q_full_idx = self.q_full_idx
+            long_seq_metadata.kv_with_q_head_nomask_idx_tensor = self.kv_idx_names[
+                'kv_with_q_head_nomask_idx_tensor']
+            long_seq_metadata.kv_with_q_head_mask_idx_tensor = self.kv_idx_names[
+                'kv_with_q_head_mask_idx_tensor']
+            long_seq_metadata.kv_with_q_tail_nomask_idx_tensor = self.kv_idx_names[
+                'kv_with_q_tail_nomask_idx_tensor']
+            long_seq_metadata.kv_with_q_tail_mask_idx_tensor = self.kv_idx_names[
+                'kv_with_q_tail_mask_idx_tensor']
+            long_seq_metadata.attn_mask_seqlens = self.extra_long_seq_kwargs[
+                'attn_mask_seqlens']
+            long_seq_metadata.head_attn_nomask_seqlens = self.extra_long_seq_kwargs[
+                'head_attn_nomask_seqlens']
+            long_seq_metadata.tail_attn_nomask_seqlens = self.extra_long_seq_kwargs[
+                'tail_attn_nomask_seqlens']
+            long_seq_metadata.pcp_prefill_mask = self.extra_long_seq_kwargs[
+                'pcp_prefill_mask']
+        return long_seq_metadata
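The head/tail index lists built in `_generate_pcp_metadata` follow the usual causal-chunk structure: a query chunk with global chunk id `c` attends to all earlier KV chunks without a mask and to its own chunk with a triangular mask. A sketch, illustration only and not code from this patch (`kv_ranges` is a hypothetical helper):

```python
# Illustration only: per-chunk no-mask vs. masked KV ranges.
def kv_ranges(chunk_id: int, chunk_len: int):
    nomask = range(0, chunk_id * chunk_len)                         # full attention
    mask = range(chunk_id * chunk_len, (chunk_id + 1) * chunk_len)  # causal part
    return nomask, mask

pcp_size, rank, chunk_len = 2, 0, 4
head_id, tail_id = rank, 2 * pcp_size - 1 - rank  # chunks 0 and 3 on rank 0
print(list(kv_ranges(head_id, chunk_len)[1]))  # [0, 1, 2, 3], masked own chunk
print(list(kv_ranges(tail_id, chunk_len)[0]))  # KV indices 0..11, no mask
print(list(kv_ranges(tail_id, chunk_len)[1]))  # [12, 13, 14, 15], masked
```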
@@ -94,19 +94,21 @@ class CachedRequestState:
 class InputBatch:
 
     def __init__(
-        self,
-        max_num_reqs: int,
-        max_model_len: int,
-        max_num_batched_tokens: int,
-        device: torch.device,
-        pin_memory: bool,
-        vocab_size: int,
-        block_sizes: list[int],  # The block_size of each kv cache group
-        logitsprocs: Optional[LogitsProcessors] = None,
-        is_spec_decode: bool = False,
-        is_pooling_model: bool = False,
-        num_speculative_tokens: int = 0,
-        kernel_block_sizes: Optional[list[list[int]]] = None):
+        self,
+        max_num_reqs: int,
+        max_model_len: int,
+        max_num_batched_tokens: int,
+        device: torch.device,
+        pin_memory: bool,
+        vocab_size: int,
+        block_sizes: list[int],  # The block_size of each kv cache group
+        logitsprocs: Optional[LogitsProcessors] = None,
+        is_spec_decode: bool = False,
+        is_pooling_model: bool = False,
+        num_speculative_tokens: int = 0,
+        kernel_block_sizes: Optional[list[list[int]]] = None,
+        cp_kv_cache_interleave_size: int = 1,
+    ):
         self.is_pooling_model = is_pooling_model
         self.is_spec_decode = is_spec_decode
         self.max_num_reqs = max_num_reqs
@@ -151,7 +153,9 @@ class InputBatch:
             device=device,
             block_sizes=block_sizes,
             num_speculative_tokens=num_speculative_tokens,
-            kernel_sizes=kernel_block_sizes)
+            kernel_sizes=kernel_block_sizes,
+            cp_kv_cache_interleave_size=cp_kv_cache_interleave_size,
+        )
 
         # Sampling-related.
         self.temperature = torch.empty((max_num_reqs, ),
@@ -49,6 +49,7 @@ from vllm_ascend.device_allocator.camem import CaMemAllocator
 from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel
 from vllm_ascend.platform import NPUPlatform
 from vllm_ascend.utils import (init_ascend_soc_version,
+                               prefill_context_parallel_enable,
                                register_ascend_customop, sleep_mode_enabled,
                                try_register_lib)
 from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
@@ -381,9 +382,17 @@ class NPUWorker(WorkerBase):
         init_distributed_environment(self.parallel_config.world_size,
                                      self.rank, self.distributed_init_method,
                                      self.local_rank, "hccl")
-        ensure_model_parallel_initialized(
-            self.parallel_config.tensor_parallel_size,
-            self.parallel_config.pipeline_parallel_size)
+        if prefill_context_parallel_enable():
+            ensure_model_parallel_initialized(
+                self.parallel_config.tensor_parallel_size,
+                self.parallel_config.pipeline_parallel_size,
+                self.parallel_config.prefill_context_parallel_size,
+                self.parallel_config.decode_context_parallel_size)
+        else:
+            ensure_model_parallel_initialized(
+                self.parallel_config.tensor_parallel_size,
+                self.parallel_config.pipeline_parallel_size,
+                self.parallel_config.decode_context_parallel_size)
         init_ascend_model_parallel(self.parallel_config)
         ensure_kv_transfer_initialized(self.vllm_config)