[long_seq_Feat] support chunk prefill (#4158)
### What this PR does / why we need it?
1、qwen GQA attention_v1 optim
2、DeepSeek MLA refactor, all gather q -> all gather kv
3、modelrunner refactor for chunk prefill, we remove some code not use
- vLLM version: v0.11.0
- vLLM main:
2918c1b49c
---------
Signed-off-by: LookAround <lixushi@huawei.com>
Signed-off-by: Delphine-Nic <tanwenqin@huawei.com>
Co-authored-by: Delphine-Nic <tanwenqin@huawei.com>
This commit is contained in:
@@ -5,7 +5,6 @@ from typing import (TYPE_CHECKING, ClassVar, List, NamedTuple, Optional, Tuple,
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn.functional as F
|
||||
import torch_npu
|
||||
from torch import nn
|
||||
from vllm.attention.backends.abstract import (AttentionBackend,
|
||||
@@ -35,14 +34,11 @@ from vllm.v1.attention.backends.utils import AttentionCGSupport
|
||||
from vllm_ascend import envs
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||
|
||||
# isort: off
|
||||
from vllm_ascend.attention.utils import (
|
||||
AscendCommonAttentionMetadata, extract_req_dcp_by_chunk_pcp,
|
||||
filter_chunked_req_indices, maybe_save_kv_layer_to_connector,
|
||||
split_decodes_and_prefills, trans_rope_weight, transdata,
|
||||
wait_for_kv_layer_from_connector)
|
||||
# isort: on
|
||||
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
||||
maybe_save_kv_layer_to_connector,
|
||||
split_decodes_and_prefills,
|
||||
trans_rope_weight, transdata,
|
||||
wait_for_kv_layer_from_connector)
|
||||
from vllm_ascend.compilation.acl_graph import (get_graph_params,
|
||||
update_graph_params_workspaces)
|
||||
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
|
||||
@@ -122,10 +118,13 @@ class AscendMLAPrefillMetadata:
|
||||
workspace: torch.Tensor
|
||||
chunk_seq_lens: torch.Tensor
|
||||
chunk_seq_lens_npu: torch.Tensor
|
||||
mask_for_non_zero_chunk: Optional[list[bool]] = None
|
||||
max_chunk_num: int = 0
|
||||
local_chunked_kv_lens: Optional[list[Optional[list[Optional[list[
|
||||
Optional[list[int]]]]]]]] = None
|
||||
# for mla DCP & PCP
|
||||
padded_chunk_seq_lens_npu: torch.Tensor = None
|
||||
padded_local_chunk_seq_lens: Optional[list[list[int]]] = None
|
||||
local_context_lens_allranks: Optional[list[list[int]]] = None
|
||||
padded_local_cu_seq_lens: torch.Tensor = None
|
||||
cu_seq_lens_lst: Optional[list[list[int]]] = None
|
||||
chunk_size: Optional[int] = None
|
||||
|
||||
attn_mask: torch.Tensor
|
||||
query_lens: torch.Tensor
|
||||
@@ -140,7 +139,6 @@ class AscendMLAPrefillMetadata:
|
||||
sin: torch.Tensor = None
|
||||
cos: torch.Tensor = None
|
||||
pcp_metadata: Optional[AscendPCPMetadata] = None
|
||||
cp_kv_recover_idx_for_chunk: Optional[list[int]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -285,6 +283,9 @@ class AscendMLAMetadataBuilder:
|
||||
self.dcp_size = get_decode_context_model_parallel_world_size()
|
||||
self.dcp_rank = get_decode_context_model_parallel_rank(
|
||||
) if self.dcp_size > 1 else 0
|
||||
self.cp_local_block_size = vllm_config.parallel_config.cp_kv_cache_interleave_size if prefill_context_parallel_enable(
|
||||
) else 1
|
||||
self.cp_virtual_block_size = self.cp_local_block_size * self.dcp_size * self.pcp_size
|
||||
decode_max_num_seqs = getattr(scheduler_config, 'decode_max_num_seqs',
|
||||
0)
|
||||
max_num_seqs = max(scheduler_config.max_num_seqs, decode_max_num_seqs)
|
||||
@@ -292,16 +293,6 @@ class AscendMLAMetadataBuilder:
|
||||
self.decode_threshold,
|
||||
dtype=torch.uint8,
|
||||
device=device)
|
||||
self.seq_mask_pcp_buf = torch.empty(max_num_seqs *
|
||||
self.decode_threshold,
|
||||
self.pcp_size,
|
||||
dtype=torch.uint8,
|
||||
device=device)
|
||||
self.seq_mask_dcp_buf = torch.empty(max_num_seqs *
|
||||
self.decode_threshold,
|
||||
self.dcp_size,
|
||||
dtype=torch.uint8,
|
||||
device=device)
|
||||
|
||||
def reorder_batch(self, input_batch: "InputBatch",
|
||||
scheduler_output: "SchedulerOutput") -> bool:
|
||||
@@ -366,10 +357,6 @@ class AscendMLAMetadataBuilder:
|
||||
|
||||
num_actual_tokens_pcp_padded = long_seq_metadata.num_actual_tokens_pcp_padded if long_seq_metadata else None
|
||||
num_computed_tokens_of_pcp_dcp = long_seq_metadata.num_computed_tokens_of_pcp_dcp if long_seq_metadata else None
|
||||
cp_kv_recover_idx_for_chunk = long_seq_metadata.cp_kv_recover_idx_for_chunk if long_seq_metadata else None
|
||||
local_chunked_kv_lens = long_seq_metadata.local_chunked_kv_lens if long_seq_metadata else None
|
||||
mask_for_non_zero_chunk = long_seq_metadata.mask_for_non_zero_chunk if long_seq_metadata else None
|
||||
max_chunk_num = long_seq_metadata.max_chunk_num if long_seq_metadata else 0
|
||||
|
||||
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
|
||||
split_decodes_and_prefills(common_attn_metadata, decode_threshold=self.decode_threshold)
|
||||
@@ -468,19 +455,75 @@ class AscendMLAMetadataBuilder:
|
||||
dim=1,
|
||||
out=cu_seq_lens_cpu[:, 1:],
|
||||
dtype=torch.int32)
|
||||
chunked_context_metadata = \
|
||||
|
||||
if self.dcp_size * self.pcp_size > 1:
|
||||
if num_computed_tokens_of_pcp_dcp is not None:
|
||||
local_context_lens_allranks = torch.tensor(
|
||||
num_computed_tokens_of_pcp_dcp[reqs_start:num_reqs]
|
||||
).reshape(-1, self.dcp_size * self.pcp_size)
|
||||
# Note(qcs): The max local context lengths
|
||||
# padded to `cp_local_block_size`.
|
||||
padded_local_context_lens_cpu = (cdiv(
|
||||
context_lens_cpu,
|
||||
self.cp_virtual_block_size,
|
||||
) * self.cp_local_block_size)
|
||||
padded_local_max_context_chunk_across_ranks = (cdiv(
|
||||
max_context_chunk,
|
||||
self.cp_virtual_block_size,
|
||||
) * self.cp_local_block_size)
|
||||
local_chunk_starts = (
|
||||
torch.arange(num_chunks,
|
||||
dtype=torch.int32).unsqueeze(1).expand(
|
||||
-1, num_prefills) *
|
||||
padded_local_max_context_chunk_across_ranks)
|
||||
local_chunk_ends = torch.min(
|
||||
padded_local_context_lens_cpu.unsqueeze(0),
|
||||
local_chunk_starts +
|
||||
padded_local_max_context_chunk_across_ranks,
|
||||
)
|
||||
padded_local_chunk_seq_lens = (local_chunk_ends -
|
||||
local_chunk_starts).clamp(
|
||||
min=0)
|
||||
padded_local_cu_chunk_seq_lens_cpu = torch.zeros(
|
||||
num_chunks,
|
||||
num_prefills + 1,
|
||||
dtype=torch.int32,
|
||||
pin_memory=True)
|
||||
torch.cumsum(
|
||||
padded_local_chunk_seq_lens,
|
||||
dim=1,
|
||||
out=padded_local_cu_chunk_seq_lens_cpu[:, 1:],
|
||||
dtype=torch.int32,
|
||||
)
|
||||
chunked_context_metadata = \
|
||||
AscendMLAPrefillMetadata.ChunkedContextMetadata(
|
||||
cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True),
|
||||
starts=chunk_starts.to(device, non_blocking=True),
|
||||
seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
|
||||
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
|
||||
chunk_seq_lens=chunk_seq_lens,
|
||||
chunk_seq_lens_npu=chunk_seq_lens.npu(),
|
||||
workspace=self.chunked_prefill_workspace,
|
||||
local_chunked_kv_lens=local_chunked_kv_lens,
|
||||
mask_for_non_zero_chunk=mask_for_non_zero_chunk,
|
||||
max_chunk_num=max_chunk_num,
|
||||
)
|
||||
cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True),
|
||||
starts=local_chunk_starts.to(device, non_blocking=True),
|
||||
seq_tot=padded_local_chunk_seq_lens.sum(dim=1).tolist(),
|
||||
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
|
||||
chunk_seq_lens=chunk_seq_lens,
|
||||
chunk_seq_lens_npu=chunk_seq_lens.npu(),
|
||||
workspace=self.chunked_prefill_workspace,
|
||||
padded_chunk_seq_lens_npu=padded_local_chunk_seq_lens.npu(),
|
||||
padded_local_chunk_seq_lens=padded_local_chunk_seq_lens.tolist(),
|
||||
local_context_lens_allranks=local_context_lens_allranks.tolist(),
|
||||
padded_local_cu_seq_lens=padded_local_cu_chunk_seq_lens_cpu.to(
|
||||
device, non_blocking=True
|
||||
),
|
||||
cu_seq_lens_lst=cu_seq_lens_cpu.tolist(),
|
||||
chunk_size=padded_local_max_context_chunk_across_ranks,
|
||||
)
|
||||
else:
|
||||
chunked_context_metadata = \
|
||||
AscendMLAPrefillMetadata.ChunkedContextMetadata(
|
||||
cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True),
|
||||
starts=chunk_starts.to(device, non_blocking=True),
|
||||
seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
|
||||
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
|
||||
chunk_seq_lens=chunk_seq_lens,
|
||||
chunk_seq_lens_npu=chunk_seq_lens.npu(),
|
||||
workspace=self.chunked_prefill_workspace,
|
||||
)
|
||||
prefill_input_positions = input_positions[tokens_start:]
|
||||
cos = self.cos_cache[
|
||||
prefill_input_positions].unsqueeze( # type: ignore
|
||||
@@ -502,7 +545,7 @@ class AscendMLAMetadataBuilder:
|
||||
sin=sin,
|
||||
cos=cos,
|
||||
pcp_metadata=pcp_metadata,
|
||||
cp_kv_recover_idx_for_chunk=cp_kv_recover_idx_for_chunk)
|
||||
)
|
||||
|
||||
decode_metadata = None
|
||||
if num_decodes > 0:
|
||||
@@ -516,7 +559,7 @@ class AscendMLAMetadataBuilder:
|
||||
block_table = block_table[:num_decodes, ...]
|
||||
# For pcp + spec decode, we flatten seq_lens and block_table
|
||||
# to avoid irregular spec_attn_mask shape
|
||||
if self.pcp_size > 1:
|
||||
if self.pcp_size > 1 and self.decode_threshold > 1:
|
||||
block_table = block_table.repeat_interleave(
|
||||
self.decode_threshold, dim=0)
|
||||
seq_lens_list = seq_lens.tolist()
|
||||
@@ -921,26 +964,8 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
prefill_metadata = attn_metadata.prefill
|
||||
if prefill_metadata is None or prefill_metadata.chunked_context is None:
|
||||
return prefix_output, prefix_lse
|
||||
local_chunked_kv_lens = prefill_metadata.chunked_context.local_chunked_kv_lens
|
||||
mask_for_non_zero_chunk = prefill_metadata.chunked_context.mask_for_non_zero_chunk
|
||||
max_chunk_num = prefill_metadata.chunked_context.max_chunk_num
|
||||
if self.pcp_size * self.dcp_size > 1:
|
||||
assert local_chunked_kv_lens is not None and mask_for_non_zero_chunk is not None and max_chunk_num > 0
|
||||
|
||||
if self.pcp_size > 1:
|
||||
prefix_output = torch.zeros(q_nope.shape[0],
|
||||
self.num_heads,
|
||||
self.v_head_dim,
|
||||
dtype=q_nope.dtype,
|
||||
device=q_nope.device)
|
||||
prefix_lse = torch.zeros(self.num_heads,
|
||||
q_pe.shape[0],
|
||||
dtype=torch.float32,
|
||||
device=q_pe.device)
|
||||
|
||||
iters = len(prefill_metadata.chunked_context.seq_tot)
|
||||
if self.pcp_size * self.dcp_size > 1:
|
||||
iters = max_chunk_num
|
||||
|
||||
current_seq_len = torch.tensor(prefill_metadata.query_lens,
|
||||
dtype=torch.int32)
|
||||
@@ -948,305 +973,97 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
cache_k_pe = kv_c_and_k_pe_cache[1]
|
||||
num_heads = cache_k_pe.size(2)
|
||||
latent_kv_dim = kv_c_and_k_pe_cache[0].size(-1)
|
||||
# token -> request mapping for building per-token masks when CP>1
|
||||
seq_len1 = torch.tensor(prefill_metadata.query_lens,
|
||||
dtype=torch.int32,
|
||||
device=q_nope.device)
|
||||
seq_len1.mul_(
|
||||
self.pcp_size) # q_full: already padded, divisible by cp_size
|
||||
|
||||
# Select mask: prefer CP prefill mask from metadata; fallback to cached prefill_mask; create if needed.
|
||||
mask_local = None
|
||||
if attn_metadata is not None and attn_metadata.prefill is not None and \
|
||||
attn_metadata.prefill.pcp_metadata is not None and attn_metadata.prefill.pcp_metadata.pcp_prefill_mask is not None:
|
||||
mask_local = attn_metadata.prefill.pcp_metadata.pcp_prefill_mask
|
||||
else:
|
||||
mask_local = self.prefill_mask
|
||||
if mask_local is None:
|
||||
mask_local = torch.triu(
|
||||
torch.ones(512,
|
||||
512,
|
||||
device=q_nope.device,
|
||||
dtype=q_nope.dtype), 1)
|
||||
self.prefill_mask = mask_local
|
||||
|
||||
# Keep the causal mask; do not override to all-ones.
|
||||
context_starts_rank = None
|
||||
|
||||
for i in range(iters):
|
||||
if self.pcp_size * self.dcp_size > 1:
|
||||
## DCP mode: each rank processes its own (cp,dcp) historical context slice per request dimension
|
||||
num_requests = len(seq_len1)
|
||||
assert num_requests == len(local_chunked_kv_lens)
|
||||
# Before dealing with a new chunk, set to zero, and accumulate the start positions as chunk prefill step increases
|
||||
context_starts_rank = torch.zeros(
|
||||
num_requests, dtype=torch.int32, device=q_nope.device
|
||||
) if context_starts_rank is None else context_starts_rank
|
||||
toks = prefill_metadata.chunked_context.seq_tot[i]
|
||||
# chunk_seq_lens will be padded when pcp&dcp
|
||||
context_seq_len = prefill_metadata.chunked_context.chunk_seq_lens[
|
||||
i]
|
||||
context_seq_len_npu = prefill_metadata.chunked_context.chunk_seq_lens_npu[
|
||||
i]
|
||||
seq_len = torch.stack([current_seq_len, context_seq_len])
|
||||
kv_c_normed = torch.empty(toks,
|
||||
num_heads,
|
||||
latent_kv_dim,
|
||||
dtype=q_nope.dtype,
|
||||
device=q_nope.device)
|
||||
k_pe = torch.empty(toks,
|
||||
num_heads,
|
||||
rope_dim,
|
||||
dtype=q_nope.dtype,
|
||||
device=q_nope.device)
|
||||
|
||||
## Calculate tokens each rank should process per request
|
||||
seq_len2_rank = torch.zeros_like(seq_len1, dtype=torch.int32)
|
||||
total_toks = 0
|
||||
|
||||
for req_idx in range(num_requests):
|
||||
if i >= len(local_chunked_kv_lens[req_idx]):
|
||||
continue
|
||||
n_computed_acc = local_chunked_kv_lens[req_idx][i]
|
||||
total_toks += n_computed_acc[self.pcp_rank][self.dcp_rank]
|
||||
seq_len2_rank[req_idx] = n_computed_acc[self.pcp_rank][
|
||||
self.dcp_rank]
|
||||
|
||||
if total_toks > 0:
|
||||
kv_c_normed = torch.empty(total_toks,
|
||||
num_heads,
|
||||
latent_kv_dim,
|
||||
dtype=q_nope.dtype,
|
||||
device=q_nope.device)
|
||||
k_pe = torch.empty(total_toks,
|
||||
num_heads,
|
||||
rope_dim,
|
||||
dtype=q_nope.dtype,
|
||||
device=q_nope.device)
|
||||
|
||||
torch_npu.atb.npu_paged_cache_load(
|
||||
cache_kv_c,
|
||||
cache_k_pe,
|
||||
prefill_metadata.block_table,
|
||||
seq_len2_rank.to(q_nope.device),
|
||||
seq_starts=
|
||||
context_starts_rank, # slot offsets of current chunk in current iteration
|
||||
key=kv_c_normed,
|
||||
value=k_pe,
|
||||
)
|
||||
seq_len2 = seq_len2_rank.to(q_nope.device)
|
||||
else:
|
||||
# If current rank has no tokens to process, create empty tensors
|
||||
kv_c_normed = torch.empty(0,
|
||||
num_heads,
|
||||
latent_kv_dim,
|
||||
dtype=q_nope.dtype,
|
||||
device=q_nope.device)
|
||||
k_pe = torch.empty(0,
|
||||
num_heads,
|
||||
rope_dim,
|
||||
dtype=q_nope.dtype,
|
||||
device=q_nope.device)
|
||||
seq_len2 = torch.zeros((len(seq_len1), ),
|
||||
dtype=torch.int32,
|
||||
device=q_nope.device)
|
||||
seq_len = torch.stack([seq_len1.cpu(), seq_len2.cpu()])
|
||||
for req_idx in range(num_requests):
|
||||
# Before dealing with a new chunk, set to zero, and accumulate the start positions as chunk prefill step increases
|
||||
if i >= len(local_chunked_kv_lens[req_idx]):
|
||||
continue
|
||||
context_starts_rank[req_idx] += local_chunked_kv_lens[
|
||||
req_idx][i][self.pcp_rank][self.dcp_rank]
|
||||
else:
|
||||
# Original logic: ChunkPrefill-only mode
|
||||
toks = prefill_metadata.chunked_context.seq_tot[i]
|
||||
context_seq_len = prefill_metadata.chunked_context.chunk_seq_lens[
|
||||
if self.dcp_size * self.pcp_size > 1:
|
||||
context_seq_len_npu = prefill_metadata.chunked_context.padded_chunk_seq_lens_npu[
|
||||
i]
|
||||
context_seq_len_npu = prefill_metadata.chunked_context.chunk_seq_lens_npu[
|
||||
i]
|
||||
seq_len = torch.stack([current_seq_len, context_seq_len])
|
||||
|
||||
kv_c_normed = torch.empty(toks,
|
||||
num_heads,
|
||||
latent_kv_dim,
|
||||
dtype=q_nope.dtype,
|
||||
device=q_nope.device)
|
||||
k_pe = torch.empty(toks,
|
||||
num_heads,
|
||||
rope_dim,
|
||||
dtype=q_nope.dtype,
|
||||
device=q_nope.device)
|
||||
torch_npu.atb.npu_paged_cache_load(
|
||||
cache_kv_c,
|
||||
cache_k_pe,
|
||||
prefill_metadata.block_table,
|
||||
context_seq_len_npu,
|
||||
seq_starts=prefill_metadata.chunked_context.starts[i],
|
||||
key=kv_c_normed,
|
||||
value=k_pe,
|
||||
)
|
||||
|
||||
torch_npu.atb.npu_paged_cache_load(
|
||||
cache_kv_c,
|
||||
cache_k_pe,
|
||||
prefill_metadata.block_table,
|
||||
context_seq_len_npu,
|
||||
seq_starts=prefill_metadata.chunked_context.starts[i],
|
||||
key=kv_c_normed,
|
||||
value=k_pe,
|
||||
cache_kv_c_k_pe = torch.cat([kv_c_normed, k_pe], dim=-1)
|
||||
if self.dcp_size > 1:
|
||||
cache_kv_c_k_pe = get_dcp_group().all_gather(
|
||||
cache_kv_c_k_pe, 0)
|
||||
|
||||
if self.pcp_size > 1:
|
||||
cache_kv_c_k_pe = get_pcp_group().all_gather(
|
||||
cache_kv_c_k_pe, 0)
|
||||
|
||||
if self.dcp_size * self.pcp_size > 1:
|
||||
allgatered_kv_c_normed, allgatered_k_pe = cache_kv_c_k_pe.split(
|
||||
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
|
||||
kv_c_normed, k_pe = self._reorg_kvcache(
|
||||
allgatered_kv_c_normed,
|
||||
allgatered_k_pe,
|
||||
padded_local_chunk_seq_lens_lst=prefill_metadata.
|
||||
chunked_context.padded_local_chunk_seq_lens[i],
|
||||
local_context_lens_allranks=prefill_metadata.
|
||||
chunked_context.local_context_lens_allranks,
|
||||
sum_seq_len=prefill_metadata.chunked_context.
|
||||
cu_seq_lens_lst[i][-1],
|
||||
max_seq_len=prefill_metadata.chunked_context.
|
||||
max_seq_lens[i],
|
||||
chunk_size=prefill_metadata.chunked_context.chunk_size,
|
||||
chunk_idx=i,
|
||||
toks=toks,
|
||||
)
|
||||
|
||||
kv_c_normed = kv_c_normed.squeeze()
|
||||
if self.dcp_size > 1:
|
||||
# DCP mode: first all_gather within DCP group, let each rank in CP group share complete sequence blocks
|
||||
# Step 1: DCP all_gather latent
|
||||
kv_c_k_pe_local = torch.cat(
|
||||
[kv_c_normed, k_pe.squeeze()],
|
||||
dim=-1) # [local_toks, latent_dim + rope_dim]
|
||||
|
||||
# Step 2: use all_gather_into_tensor_uneven (gather + cat)
|
||||
req_dcp_sizes = extract_req_dcp_by_chunk_pcp(
|
||||
local_chunked_kv_lens, i, self.dcp_size, self.pcp_rank
|
||||
) # need to know num tokens of each rank in dcp group before all_gather # [reqs, dcp]
|
||||
assert len(req_dcp_sizes) == num_requests and all(
|
||||
len(dcp_arr) == self.dcp_size for dcp_arr in req_dcp_sizes)
|
||||
total_toks = np.sum(np.array(req_dcp_sizes))
|
||||
latent_rope_dim = kv_c_k_pe_local.size(-1)
|
||||
kv_c_k_pe_full = torch.empty((total_toks, latent_rope_dim),
|
||||
device=kv_c_k_pe_local.device,
|
||||
dtype=kv_c_k_pe_local.dtype)
|
||||
|
||||
kv_c_k_pe_full_list = [None for _ in range(self.dcp_size)]
|
||||
dist.all_gather_object(kv_c_k_pe_full_list,
|
||||
kv_c_k_pe_local,
|
||||
group=self.dcp_group)
|
||||
kv_c_k_pe_full_list = [
|
||||
kv_c_k_pe for kv_c_k_pe in kv_c_k_pe_full_list
|
||||
if kv_c_k_pe is not None and kv_c_k_pe.numel() > 0
|
||||
]
|
||||
if len(kv_c_k_pe_full_list) > 0:
|
||||
kv_c_k_pe_full = torch.cat(kv_c_k_pe_full_list, dim=0)
|
||||
if len(kv_c_k_pe_full.shape) == 1:
|
||||
assert total_toks == 1
|
||||
kv_c_k_pe_full = kv_c_k_pe_full.unsqueeze(0)
|
||||
assert kv_c_k_pe_full.shape[
|
||||
0] == total_toks and kv_c_k_pe_full.shape[
|
||||
1] == latent_rope_dim
|
||||
kv_c_normed_full, k_pe_full = torch.split(
|
||||
kv_c_k_pe_full, [latent_kv_dim, rope_dim], dim=-1)
|
||||
|
||||
# Step 3: process complete sequence with TP projection to get current rank's head slice
|
||||
# Case that no kv_cache has been stored on this CP rank(after dcp all_gather), no need to do following computation.
|
||||
if total_toks == 0:
|
||||
continue
|
||||
kv_nope = self.kv_b_proj(kv_c_normed_full)[0].view(
|
||||
-1, self.num_heads,
|
||||
self.qk_nope_head_dim + self.v_head_dim)
|
||||
k_nope, v = kv_nope.split(
|
||||
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||
k_pe = k_pe_full.unsqueeze(1).expand((*k_nope.shape[:-1], -1))
|
||||
|
||||
seq_len2 = torch.tensor(np.sum(np.array(req_dcp_sizes),
|
||||
axis=1),
|
||||
dtype=torch.int32,
|
||||
device=q_nope.device) # [reqs]
|
||||
seq_len = torch.stack([seq_len1.cpu(), seq_len2.cpu()])
|
||||
else:
|
||||
# Non-DCP mode: use TP-split projection
|
||||
kv_nope = self.kv_b_proj(kv_c_normed)[0].view(
|
||||
-1, self.num_heads,
|
||||
self.qk_nope_head_dim + self.v_head_dim)
|
||||
k_nope, v = kv_nope.split(
|
||||
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||
k_pe = k_pe.expand((*k_nope.shape[:-1], -1))
|
||||
kv_nope = self.kv_b_proj(kv_c_normed)[0].view( \
|
||||
-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
|
||||
k_nope, v = kv_nope\
|
||||
.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||
k_pe = k_pe.expand((*k_nope.shape[:-1], -1))
|
||||
|
||||
if self.pcp_size > 1:
|
||||
# Case that no kv_cache has been stored on this CP rank, no need to do following computation.
|
||||
if torch.all(seq_len2 == 0).item():
|
||||
continue
|
||||
# PCP mode: first compute this rank's contribution to the chunk
|
||||
if i == 0:
|
||||
torch_npu.atb.npu_ring_mla(
|
||||
q_nope=q_nope,
|
||||
q_rope=q_pe,
|
||||
k_nope=k_nope,
|
||||
k_rope=k_pe,
|
||||
value=v,
|
||||
mask=mask_local,
|
||||
seqlen=seq_len,
|
||||
head_num=self.num_heads,
|
||||
kv_head_num=self.num_heads,
|
||||
pre_out=None,
|
||||
prev_lse=None,
|
||||
qk_scale=self.scale,
|
||||
kernel_type="kernel_type_high_precision",
|
||||
mask_type="no_mask",
|
||||
input_layout="type_bsnd",
|
||||
calc_type="calc_type_first_ring",
|
||||
output=prefix_output,
|
||||
softmax_lse=prefix_lse)
|
||||
continue
|
||||
|
||||
torch_npu.atb.npu_ring_mla(
|
||||
q_nope=q_nope,
|
||||
q_rope=q_pe,
|
||||
k_nope=k_nope,
|
||||
k_rope=k_pe,
|
||||
value=v,
|
||||
mask=mask_local,
|
||||
seqlen=seq_len,
|
||||
head_num=self.num_heads,
|
||||
kv_head_num=self.num_heads,
|
||||
pre_out=prefix_output,
|
||||
prev_lse=prefix_lse,
|
||||
qk_scale=self.scale,
|
||||
kernel_type="kernel_type_high_precision",
|
||||
mask_type="no_mask",
|
||||
input_layout="type_bsnd",
|
||||
calc_type="calc_type_default",
|
||||
output=prefix_output,
|
||||
softmax_lse=prefix_lse)
|
||||
|
||||
mask = attn_metadata.prefill.pcp_metadata.pcp_prefill_mask
|
||||
else:
|
||||
assert not torch.all(context_seq_len == 0).item()
|
||||
# compute this chunk block then update prefix tensors to keep shapes consistent
|
||||
torch_npu.atb.npu_ring_mla(
|
||||
q_nope=q_nope,
|
||||
q_rope=q_pe,
|
||||
k_nope=k_nope,
|
||||
k_rope=k_pe,
|
||||
value=v,
|
||||
mask=mask_local,
|
||||
seqlen=seq_len,
|
||||
head_num=self.num_heads,
|
||||
kv_head_num=self.num_heads,
|
||||
pre_out=prefix_output,
|
||||
prev_lse=prefix_lse,
|
||||
qk_scale=self.scale,
|
||||
kernel_type="kernel_type_high_precision",
|
||||
mask_type="no_mask",
|
||||
input_layout="type_bsnd",
|
||||
calc_type="calc_type_default",
|
||||
output=prefix_output,
|
||||
softmax_lse=prefix_lse)
|
||||
|
||||
# CP dimension all_gather and fusion
|
||||
if self.pcp_size > 1:
|
||||
# filter non-zero chunk part of prefix_output
|
||||
seq_len1_cpu = seq_len1.cpu()
|
||||
filtered_indices = filter_chunked_req_indices(
|
||||
seq_len1_cpu, mask_for_non_zero_chunk)
|
||||
prefix_output_filtered = prefix_output[filtered_indices, :, :]
|
||||
prefix_lse_filtered = prefix_lse[:, filtered_indices]
|
||||
|
||||
# normalize prefix LSE to [bs, heads, 1] for stable updates
|
||||
prefix_lse_filtered_bt = prefix_lse_filtered.permute(
|
||||
1, 0).unsqueeze(-1).contiguous(
|
||||
) if prefix_lse_filtered is not None else None
|
||||
out_lse_local = torch.cat(
|
||||
[prefix_output_filtered, prefix_lse_filtered_bt], dim=-1)
|
||||
out_lse_list = [
|
||||
torch.empty_like(out_lse_local) for _ in range(self.pcp_size)
|
||||
]
|
||||
dist.all_gather(out_lse_list, out_lse_local, group=self.pcp_group)
|
||||
prefix_output_filtered = None
|
||||
prefix_lse_filtered_bt = None
|
||||
for r in range(self.pcp_size):
|
||||
out_lse_r = out_lse_list[r]
|
||||
if torch.all(out_lse_r == 0).item():
|
||||
continue
|
||||
out_r, lse_r = torch.split(out_lse_r, [self.v_head_dim, 1],
|
||||
dim=-1)
|
||||
token_mask = torch.ones([out_r.size(0)],
|
||||
dtype=torch.uint8,
|
||||
device=out_r.device)
|
||||
prefix_output_filtered, prefix_lse_filtered_bt = self._update_out_and_lse(
|
||||
prefix_output_filtered, prefix_lse_filtered_bt, out_r,
|
||||
lse_r, token_mask)
|
||||
# convert lse back to [heads, bs]
|
||||
assert prefix_output_filtered is not None and prefix_lse_filtered_bt is not None
|
||||
prefix_lse_filtered = prefix_lse_filtered_bt.squeeze(-1).permute(
|
||||
1, 0).contiguous()
|
||||
|
||||
prefix_output[filtered_indices, :, :] = prefix_output_filtered.to(
|
||||
prefix_output.dtype)
|
||||
prefix_lse[:, filtered_indices] = prefix_lse_filtered.to(
|
||||
prefix_lse.dtype)
|
||||
|
||||
mask = self.prefill_mask
|
||||
torch_npu.atb.npu_ring_mla(
|
||||
q_nope=q_nope,
|
||||
q_rope=q_pe,
|
||||
k_nope=k_nope,
|
||||
k_rope=k_pe,
|
||||
value=v,
|
||||
mask=mask,
|
||||
seqlen=seq_len,
|
||||
head_num=self.num_heads,
|
||||
kv_head_num=self.num_heads,
|
||||
pre_out=prefix_output,
|
||||
prev_lse=prefix_lse,
|
||||
qk_scale=self.scale,
|
||||
kernel_type="kernel_type_high_precision",
|
||||
mask_type="no_mask",
|
||||
input_layout="type_bsnd",
|
||||
calc_type="calc_type_default",
|
||||
output=prefix_output,
|
||||
softmax_lse=prefix_lse)
|
||||
return prefix_output, prefix_lse
|
||||
|
||||
def _forward_prefill(
|
||||
@@ -1814,8 +1631,7 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
head_attn_nomask_seqlens = attn_metadata.prefill.pcp_metadata.head_attn_nomask_seqlens
|
||||
tail_attn_nomask_seqlens = attn_metadata.prefill.pcp_metadata.tail_attn_nomask_seqlens
|
||||
mask = attn_metadata.prefill.pcp_metadata.pcp_prefill_mask
|
||||
|
||||
output_head, head_lse = self._attention_with_mask_and_nomask(
|
||||
output_head, lse_head = self._attention_with_mask_and_nomask(
|
||||
q_nope=torch.index_select(q_nope, 0, q_head_idx),
|
||||
q_pe=torch.index_select(q_pe, 0, q_head_idx),
|
||||
k_nope=k_nope,
|
||||
@@ -1827,7 +1643,7 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
attn_nomask_seqlens=head_attn_nomask_seqlens,
|
||||
mask=mask)
|
||||
|
||||
output_tail, tail_lse = self._attention_with_mask_and_nomask(
|
||||
output_tail, lse_tail = self._attention_with_mask_and_nomask(
|
||||
q_nope=torch.index_select(q_nope, 0, q_tail_idx),
|
||||
q_pe=torch.index_select(q_pe, 0, q_tail_idx),
|
||||
k_nope=k_nope,
|
||||
@@ -1840,86 +1656,15 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
mask=mask)
|
||||
|
||||
q_full_idx = attn_metadata.prefill.pcp_metadata.q_full_idx
|
||||
output = torch.index_select(
|
||||
attn_output = torch.index_select(
|
||||
torch.cat([output_head, output_tail], dim=0), 0, q_full_idx)
|
||||
attn_lse = torch.index_select(torch.cat([lse_head, lse_tail], dim=1),
|
||||
1, q_full_idx)
|
||||
|
||||
# Synchronize and reorder LSE for subsequent chunked context accumulation
|
||||
attn_lse = torch.cat([head_lse, tail_lse], dim=1)
|
||||
attn_lse = attn_lse[:, q_full_idx]
|
||||
output, _ = self._compute_prefill_context( \
|
||||
q_nope, q_pe, kv_c_and_k_pe_cache, self.qk_rope_head_dim, attn_metadata, attn_output, attn_lse)
|
||||
|
||||
# Post-processing: keep [tokens, H, V] shape and perform chunked context accumulation if needed
|
||||
if attn_metadata.prefill is not None and \
|
||||
attn_metadata.prefill.chunked_context is not None:
|
||||
# q all_gather
|
||||
q_nope_full = get_pcp_group().all_gather(q_nope.contiguous(), 0)
|
||||
q_pe_full = get_pcp_group().all_gather(q_pe.contiguous(), 0)
|
||||
q_nope_full = torch.index_select(
|
||||
q_nope_full, 0,
|
||||
attn_metadata.prefill.cp_kv_recover_idx_for_chunk)
|
||||
q_pe_full = torch.index_select(
|
||||
q_pe_full, 0,
|
||||
attn_metadata.prefill.cp_kv_recover_idx_for_chunk)
|
||||
attn_output_pre = output.view(num_tokens, self.num_heads,
|
||||
self.v_head_dim)
|
||||
attn_output_pre_full, attn_lse_full = self._compute_prefill_context(
|
||||
q_nope_full,
|
||||
q_pe_full,
|
||||
kv_c_and_k_pe_cache,
|
||||
self.qk_rope_head_dim,
|
||||
attn_metadata,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
# reorder back && extract output + lse result of each cp rank
|
||||
inverse_idx = torch.argsort(
|
||||
attn_metadata.prefill.cp_kv_recover_idx_for_chunk)
|
||||
attn_output_pre_full = torch.index_select(attn_output_pre_full, 0,
|
||||
inverse_idx)
|
||||
attn_lse_full = torch.index_select(attn_lse_full, 1, inverse_idx)
|
||||
attn_output_pre_new = attn_output_pre_full[
|
||||
self.pcp_rank * num_tokens:(self.pcp_rank + 1) *
|
||||
num_tokens, :, :]
|
||||
attn_lse_new = attn_lse_full[:, self.pcp_rank *
|
||||
num_tokens:(self.pcp_rank + 1) *
|
||||
num_tokens]
|
||||
|
||||
# update(output_origin, output_new)
|
||||
assert attn_output_pre_new.shape == attn_output_pre.shape and attn_lse_new.shape == attn_lse.shape
|
||||
seq_len = torch.tensor(attn_metadata.prefill.query_lens,
|
||||
dtype=torch.int32)
|
||||
mask_for_non_zero_chunk = attn_metadata.prefill.chunked_context.mask_for_non_zero_chunk
|
||||
filtered_indices = filter_chunked_req_indices(
|
||||
seq_len, mask_for_non_zero_chunk)
|
||||
attn_output_pre_filtered = attn_output_pre[filtered_indices, :, :]
|
||||
attn_lse_filtered = attn_lse[:, filtered_indices]
|
||||
attn_output_pre_new = attn_output_pre_new[filtered_indices, :, :]
|
||||
attn_lse_new = attn_lse_new[:, filtered_indices]
|
||||
|
||||
# normalize prefix LSE to [bs, heads, 1] for stable updates
|
||||
attn_lse_filtered = attn_lse_filtered.permute(1, 0).unsqueeze(-1)
|
||||
attn_lse_new = attn_lse_new.permute(1, 0).unsqueeze(-1)
|
||||
token_mask = torch.ones([attn_lse_new.size(0)],
|
||||
dtype=torch.uint8,
|
||||
device=attn_lse_new.device)
|
||||
attn_output_pre_filtered, attn_lse_filtered = self._update_out_and_lse(
|
||||
attn_output_pre_filtered, attn_lse_filtered,
|
||||
attn_output_pre_new, attn_lse_new, token_mask)
|
||||
# convert lse back to [heads, bs]
|
||||
attn_lse_filtered = attn_lse_filtered.squeeze(-1).permute(
|
||||
1, 0).contiguous()
|
||||
|
||||
attn_output_pre[
|
||||
filtered_indices, :, :] = attn_output_pre_filtered.to(
|
||||
attn_output_pre.dtype)
|
||||
attn_lse[:,
|
||||
filtered_indices] = attn_lse_filtered.to(attn_lse.dtype)
|
||||
|
||||
attn_output_pre = attn_output_pre.to(q_nope.dtype)
|
||||
output = attn_output_pre.reshape(
|
||||
[num_tokens, self.num_heads * self.v_head_dim])
|
||||
else:
|
||||
output = output.reshape(
|
||||
[num_tokens, self.num_heads * self.v_head_dim])
|
||||
output = output.reshape([num_tokens, self.num_heads * self.v_head_dim])
|
||||
|
||||
return output
|
||||
|
||||
@@ -2164,32 +1909,75 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
|
||||
return attn_out_lse_list
|
||||
|
||||
# TODO use update op to replace this
|
||||
def _update_out_and_lse(
|
||||
def _reorg_kvcache(
|
||||
self,
|
||||
out: torch.Tensor,
|
||||
lse: torch.Tensor,
|
||||
block_out: torch.Tensor,
|
||||
block_lse: torch.Tensor,
|
||||
mask: torch.Tensor = None,
|
||||
):
|
||||
if out is None:
|
||||
out = block_out.to(torch.float32)
|
||||
lse = block_lse
|
||||
else:
|
||||
if mask is None:
|
||||
mask = torch.ones([block_out.size(0)],
|
||||
dtype=torch.uint8,
|
||||
device=block_out.device)
|
||||
out_mask = mask[:, None, None].expand_as(block_out)
|
||||
lse_mask = mask[:, None, None].expand_as(block_lse)
|
||||
block_out = block_out.to(torch.float32)
|
||||
out_without_update = out.clone()
|
||||
lse_without_update = lse.clone()
|
||||
|
||||
out = out - F.sigmoid(block_lse - lse) * (out - block_out)
|
||||
lse = lse - F.logsigmoid(lse - block_lse)
|
||||
# mask
|
||||
out = torch.where(out_mask, out, out_without_update)
|
||||
lse = torch.where(lse_mask, lse, lse_without_update)
|
||||
return out, lse
|
||||
allgatered_kv_c_normed: torch.Tensor,
|
||||
allgatered_k_pe: torch.Tensor,
|
||||
padded_local_chunk_seq_lens_lst: list[int],
|
||||
local_context_lens_allranks: list[list[int]],
|
||||
sum_seq_len: int,
|
||||
max_seq_len: int,
|
||||
chunk_size: int,
|
||||
chunk_idx: int,
|
||||
toks: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
reorg and unpad kvcache after cp local gather to tp layout for attn kernel.
|
||||
e.g.
|
||||
kv_c_normed in rank0 = [T0_0, T0_1, T0_2, T0_3, T1_0, T1_1, ...]
|
||||
kv_c_normed in rank1 = [T0_4, T0_5, pad, pad, T1_2, pad, ...]
|
||||
allgatered_kv_c_normed = [T0_0, T0_1, T0_2, T0_3, T1_0, T1_1, ...,
|
||||
T0_4, T0_5, pad, pad, T1_2, pad, ...]
|
||||
-> reorganized_kv_c_normed = [T0_0, T0_1, T0_2, T0_3, T0_4, T0_5,
|
||||
T1_0, T1_1, T1_2, ...]
|
||||
Args:
|
||||
padded_local_chunk_seq_lens_lst: local chunk context lengths
|
||||
under current CP rank.
|
||||
local_context_lens_allranks: local context lengths on each CP rank.
|
||||
sum_seq_len: the sum of cp_chunk_seq_lens_lst.
|
||||
max_seq_len: the max value of cp_chunk_seq_lens_lst.
|
||||
chunk_size: the local padded max context chunk from
|
||||
chunked_context_metadata building.
|
||||
chunk_idx: chunk idx of chunked_prefill.
|
||||
toks: the number of tokens for local gather cache.
|
||||
"""
|
||||
kv_c_segments = []
|
||||
k_pe_segments = []
|
||||
src_token_idx = 0
|
||||
max_seq_len_check = 0
|
||||
for padded_local_chunk_seq_len, local_context_lens in zip(
|
||||
padded_local_chunk_seq_lens_lst, local_context_lens_allranks):
|
||||
cur_seq_len = 0
|
||||
for rank, local_context_len in enumerate(local_context_lens):
|
||||
# Note(qcs): We split the context into multiple chunks,
|
||||
# depending on the size of the workspace.
|
||||
# local_context in dcp0: |-----------------|
|
||||
# local_context in dcp1: |--------------|
|
||||
# n*padded_local_chunk: |-----|-----|-----|
|
||||
# local_chunk_len in dcp1: |-----|-----|--|
|
||||
# so we need update the last chunk length in dcp1.
|
||||
local_chunk_len = min(
|
||||
max(0, local_context_len - chunk_idx * chunk_size),
|
||||
padded_local_chunk_seq_len,
|
||||
)
|
||||
if local_chunk_len != 0:
|
||||
kv_c_segment = allgatered_kv_c_normed[rank * toks +
|
||||
src_token_idx:rank *
|
||||
toks +
|
||||
src_token_idx +
|
||||
local_chunk_len]
|
||||
k_pe_segment = allgatered_k_pe[rank * toks +
|
||||
src_token_idx:rank * toks +
|
||||
src_token_idx +
|
||||
local_chunk_len]
|
||||
kv_c_segments.append(kv_c_segment)
|
||||
k_pe_segments.append(k_pe_segment)
|
||||
cur_seq_len += local_chunk_len
|
||||
max_seq_len_check = max(max_seq_len_check, cur_seq_len)
|
||||
src_token_idx += padded_local_chunk_seq_len
|
||||
reorganized_kv_c_normed = torch.cat(kv_c_segments, dim=0)
|
||||
reorganized_k_pe = torch.cat(k_pe_segments, dim=0)
|
||||
assert reorganized_kv_c_normed.shape[0] == sum_seq_len
|
||||
assert reorganized_k_pe.shape[0] == sum_seq_len
|
||||
assert max_seq_len_check == max_seq_len
|
||||
return reorganized_kv_c_normed, reorganized_k_pe
|
||||
|
||||
Reference in New Issue
Block a user