### What this PR does / why we need it?
**Scope of Changes**:
| File Path |
| :--- |
| `vllm_ascend/attention/attention_mask.py` |
| `vllm_ascend/attention/attention_v1.py` |
| `vllm_ascend/attention/context_parallel/attention_cp.py` |
| `vllm_ascend/attention/context_parallel/common_cp.py` |
| `vllm_ascend/attention/context_parallel/mla_cp.py` |
| `vllm_ascend/attention/utils.py` |
| `vllm_ascend/batch_invariant.py` |
| `vllm_ascend/device/device_op.py` |
| `vllm_ascend/device_allocator/camem.py` |
| `vllm_ascend/envs.py` |
- vLLM version: v0.13.0
- vLLM main: 2c24bc6996
---------
Signed-off-by: MrZ20 <2609716663@qq.com>
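
In the diff excerpt below (which appears to be `vllm_ascend/attention/utils.py`, one of the files listed above), the changes are cosmetic: statements that were wrapped across lines are reflowed onto single lines, and type annotations are modernized. As a quick reference, here is a minimal sketch of the target annotation style; the class and field names are invented for illustration and are not from this PR:

```python
# Minimal sketch (not code from this PR; names are invented) of the annotation
# style the diff moves to: builtin generics (PEP 585) and "X | None" (PEP 604)
# instead of typing.List / typing.Optional. Runtime-evaluated "X | None"
# annotations (e.g., dataclass fields) require Python 3.10+.
from dataclasses import dataclass, field


@dataclass
class ExampleMetadata:
    # previously: Optional[list[list[int]]] = None
    token_counts: list[list[int]] | None = None
    # previously: List[bool] = field(default_factory=list)
    chunk_mask: list[bool] = field(default_factory=list)
```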
```diff
@@ -1,18 +1,15 @@
 from dataclasses import dataclass, field
 from functools import lru_cache
-from typing import Any, List, Optional
+from typing import Any
 
 import torch
 import torch.nn.functional as F
 from vllm.config import VllmConfig, get_current_vllm_config
-from vllm.distributed.kv_transfer import (get_kv_transfer_group,
-                                          has_kv_transfer_group,
-                                          is_v1_kv_transfer_group)
+from vllm.distributed.kv_transfer import get_kv_transfer_group, has_kv_transfer_group, is_v1_kv_transfer_group
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.v1.attention.backends.utils import CommonAttentionMetadata
 
-from vllm_ascend.utils import (AscendDeviceType, get_ascend_config,
-                               get_ascend_device_type)
+from vllm_ascend.utils import AscendDeviceType, get_ascend_config, get_ascend_device_type
 
 
 def using_paged_attention(runtime_shape: int, vllm_config: VllmConfig) -> bool:
@@ -21,6 +18,7 @@ def using_paged_attention(runtime_shape: int, vllm_config: VllmConfig) -> bool:
     if get_ascend_device_type() == AscendDeviceType.A5:
         return False
     from vllm.config.compilation import CUDAGraphMode
+
     cudagraph_mode = vllm_config.compilation_config.cudagraph_mode
     if cudagraph_mode != CUDAGraphMode.FULL_DECODE_ONLY:
         return False
@@ -31,8 +29,7 @@ def using_paged_attention(runtime_shape: int, vllm_config: VllmConfig) -> bool:
 @lru_cache(maxsize=1)
 def enable_cp():
     prefill_config = get_current_vllm_config().parallel_config
-    return prefill_config.prefill_context_parallel_size > 1 \
-        or prefill_config.decode_context_parallel_size > 1
+    return prefill_config.prefill_context_parallel_size > 1 or prefill_config.decode_context_parallel_size > 1
 
 
 @dataclass
@@ -42,13 +39,14 @@ class AscendPrefillContextParallelMetadata:
 
     Contains index tensors and sequence lengths for PCP operations.
     """
+
     pcp_allgather_restore_idx: torch.Tensor = None
 
     cp_kv_recover_idx_for_chunk: torch.Tensor = None
 
     num_actual_tokens_pcp_padded: int = 0
 
-    num_computed_tokens_of_pcp_dcp: Optional[list[list[list[int]]]] = None
+    num_computed_tokens_of_pcp_dcp: list[list[list[int]]] | None = None
 
     q_head_idx_tensor: torch.Tensor = None
 
@@ -85,6 +83,7 @@ class AscendCommonAttentionMetadata(CommonAttentionMetadata):
 
     For many of the tensors we keep both NPU and CPU versions.
     """
+
     # CPU tensor of sequence lengths for host-side operations.
     # E.g., tensor([128, 256, 64]) for 3 requests with different seq lengths.
     seq_lens_cpu: torch.Tensor = None
@@ -115,20 +114,17 @@ class AscendCommonAttentionMetadata(CommonAttentionMetadata):
     num_input_tokens: int = 0
 
     # Metadata for Prefill Context Parallelism (PCP) operations.
-    prefill_context_parallel_metadata: Optional[
-        AscendPrefillContextParallelMetadata] = None
+    prefill_context_parallel_metadata: AscendPrefillContextParallelMetadata | None = None
 
     # TODO: Remove it when vLLM no longer uses this function.
-    def unpadded(self, num_actual_tokens: int,
-                 num_actual_reqs: int) -> "AscendCommonAttentionMetadata":
+    def unpadded(self, num_actual_tokens: int, num_actual_reqs: int) -> "AscendCommonAttentionMetadata":
         # This only use to eagle now. It will be use to enforce_eager in future.
         return AscendCommonAttentionMetadata(
-            query_start_loc=self.query_start_loc[:num_actual_reqs + 1],
-            query_start_loc_cpu=self.query_start_loc_cpu[:num_actual_reqs + 1],
+            query_start_loc=self.query_start_loc[: num_actual_reqs + 1],
+            query_start_loc_cpu=self.query_start_loc_cpu[: num_actual_reqs + 1],
             seq_lens=self.seq_lens[:num_actual_reqs],
             seq_lens_cpu=self.seq_lens_cpu[:num_actual_reqs],
-            num_computed_tokens_cpu=self.
-            num_computed_tokens_cpu[:num_actual_reqs],
+            num_computed_tokens_cpu=self.num_computed_tokens_cpu[:num_actual_reqs],
             num_reqs=num_actual_reqs,
             num_actual_tokens=num_actual_tokens,
             max_query_len=self.max_query_len,
@@ -144,14 +140,14 @@ class AscendCommonAttentionMetadata(CommonAttentionMetadata):
             attn_state=self.attn_state,
             graph_pad_size=-1,  # It should be -1 when not run in fullgraph mode.
             num_input_tokens=self.num_input_tokens,
-            prefill_context_parallel_metadata=self.
-            prefill_context_parallel_metadata,
-            max_seq_len=self.max_seq_len)
+            prefill_context_parallel_metadata=self.prefill_context_parallel_metadata,
+            max_seq_len=self.max_seq_len,
+        )
 
 
 def filter_chunked_req_indices(
     seq_len: torch.Tensor,
-    mask_for_non_zero_chunk: Optional[List[bool]],
+    mask_for_non_zero_chunk: list[bool] | None,
 ) -> torch.Tensor:
     """
     filter the reqs which are doing real chunk_prefill.
@@ -162,14 +158,15 @@ def filter_chunked_req_indices(
     Returns:
         filtered_indices: the real chunked req's indices
     """
-    assert mask_for_non_zero_chunk is not None and len(seq_len) == len(
-        mask_for_non_zero_chunk)
+    assert mask_for_non_zero_chunk is not None and len(seq_len) == len(mask_for_non_zero_chunk)
     offsets = torch.cumsum(torch.cat([torch.tensor([0]), seq_len[:-1]]), dim=0)
-    filtered_indices = torch.cat([
-        torch.arange(offsets[i], offsets[i] + seq_len[i])
-        for i in range(len(mask_for_non_zero_chunk))
-        if mask_for_non_zero_chunk[i]
-    ])
+    filtered_indices = torch.cat(
+        [
+            torch.arange(offsets[i], offsets[i] + seq_len[i])
+            for i in range(len(mask_for_non_zero_chunk))
+            if mask_for_non_zero_chunk[i]
+        ]
+    )
     return filtered_indices
 
 
@@ -195,12 +192,9 @@ def split_decodes_and_prefills(
         num_prefill_tokens: The number of tokens in the prefill requests.
     """
     long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
-    query_lens_pcp_full = long_seq_metadata.query_lens_pcp_full_cpu \
-        if long_seq_metadata else None
-    max_query_len_pcp_full = long_seq_metadata.max_query_len_pcp_full \
-        if long_seq_metadata else 0
-    max_query_len = common_attn_metadata.max_query_len \
-        if max_query_len_pcp_full == 0 else max_query_len_pcp_full
+    query_lens_pcp_full = long_seq_metadata.query_lens_pcp_full_cpu if long_seq_metadata else None
+    max_query_len_pcp_full = long_seq_metadata.max_query_len_pcp_full if long_seq_metadata else 0
+    max_query_len = common_attn_metadata.max_query_len if max_query_len_pcp_full == 0 else max_query_len_pcp_full
     num_reqs = common_attn_metadata.num_reqs
     num_tokens = common_attn_metadata.num_actual_tokens
     query_start_loc = common_attn_metadata.query_start_loc_cpu
@@ -208,8 +202,7 @@ def split_decodes_and_prefills(
     if max_query_len <= decode_threshold:
         return num_reqs, 0, num_tokens, 0
 
-    query_lens = (query_start_loc[1:] - query_start_loc[:-1]) \
-        if query_lens_pcp_full is None else query_lens_pcp_full
+    query_lens = (query_start_loc[1:] - query_start_loc[:-1]) if query_lens_pcp_full is None else query_lens_pcp_full
     is_prefill = query_lens > decode_threshold
     if not torch.any(is_prefill):
         return num_reqs, 0, num_tokens, 0
@@ -238,7 +231,7 @@ def wait_for_kv_layer_from_connector(layer_name: str):
 
 def maybe_save_kv_layer_to_connector(
     layer_name: str,
-    kv_cache_layer: List[torch.Tensor],
+    kv_cache_layer: list[torch.Tensor],
 ):
     if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
         return
@@ -264,8 +257,7 @@ def trans_rope_weight(weight, rope_dim):
         return weight.contiguous()
     nope_part = weight[..., :-rope_dim, :]
     rope_part = weight[..., -rope_dim:, :]
-    reordered_rope_part = torch.cat(
-        (rope_part[..., ::2, :], rope_part[..., 1::2, :]), dim=-2)
+    reordered_rope_part = torch.cat((rope_part[..., ::2, :], rope_part[..., 1::2, :]), dim=-2)
    return torch.cat((nope_part, reordered_rope_part), dim=-2).contiguous()
 
 
@@ -278,12 +270,9 @@ def transdata(nd_mat, block_size: tuple = (16, 16)):
     nz_mat = torch.permute(
         torch.reshape(
             nd_mat,
-            (r // block_size[0], block_size[0], c // block_size[1],
-             block_size[1]),
+            (r // block_size[0], block_size[0], c // block_size[1], block_size[1]),
         ),
         [2, 0, 1, 3],
     )
-    nz_mat = torch.reshape(
-        nz_mat,
-        (nz_mat.shape[0], nz_mat.shape[1] * nz_mat.shape[2], nz_mat.shape[3]))
+    nz_mat = torch.reshape(nz_mat, (nz_mat.shape[0], nz_mat.shape[1] * nz_mat.shape[2], nz_mat.shape[3]))
     return nz_mat
```
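
The `transdata` hunk just above amounts to cutting the matrix into 16x16 blocks, moving the column-block axis to the front, and merging the row axes back. A self-contained shape walk-through, re-created from the body shown above (the 32x48 input and float dtype are illustrative, not from the PR):

```python
# Shape walk-through re-created from the transdata body shown in the diff.
import torch

r, c = 32, 48                      # must be multiples of the 16x16 block size
nd_mat = torch.arange(r * c, dtype=torch.float32).reshape(r, c)

# (r, c) -> (r/16, 16, c/16, 16): cut the matrix into 16x16 blocks.
blocked = torch.reshape(nd_mat, (r // 16, 16, c // 16, 16))
# Move the column-block axis to the front: (c/16, r/16, 16, 16).
permuted = torch.permute(blocked, [2, 0, 1, 3])
# Merge the row axes back together: (c/16, r, 16).
nz_mat = torch.reshape(permuted, (permuted.shape[0], permuted.shape[1] * permuted.shape[2], permuted.shape[3]))

print(nz_mat.shape)  # torch.Size([3, 32, 16])
```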
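Similarly, the `filter_chunked_req_indices` hunk is easier to read with a concrete input. Below is a self-contained re-implementation of the body shown in the diff, with made-up demo values (it is not an import of the real module):

```python
# Re-implementation of the filter_chunked_req_indices body shown in the diff,
# for illustration only.
import torch


def filter_chunked_req_indices(seq_len: torch.Tensor, mask_for_non_zero_chunk: list[bool]) -> torch.Tensor:
    # Start offset of each request's token span in the flattened token stream.
    offsets = torch.cumsum(torch.cat([torch.tensor([0]), seq_len[:-1]]), dim=0)
    # Keep only the token indices of requests doing a real chunked prefill.
    return torch.cat(
        [
            torch.arange(offsets[i], offsets[i] + seq_len[i])
            for i in range(len(mask_for_non_zero_chunk))
            if mask_for_non_zero_chunk[i]
        ]
    )


seq_len = torch.tensor([3, 2, 4])    # tokens per request
mask = [True, False, True]           # which requests have a non-zero chunk
print(filter_chunked_req_indices(seq_len, mask))
# tensor([0, 1, 2, 5, 6, 7, 8]) -> token positions of requests 0 and 2
```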