[long_seq_Feat] support chunk prefill (#4158)

### What this PR does / why we need it?
1、qwen GQA attention_v1 optim
2、DeepSeek MLA refactor, all gather q -> all gather kv 
3、modelrunner refactor for chunk prefill, we remove some code not use

- vLLM version: v0.11.0
- vLLM main:
2918c1b49c

---------

Signed-off-by: LookAround <lixushi@huawei.com>
Signed-off-by: Delphine-Nic <tanwenqin@huawei.com>
Co-authored-by: Delphine-Nic <tanwenqin@huawei.com>
This commit is contained in:
LookAround0301
2025-11-14 08:43:37 +08:00
committed by GitHub
parent 7294f89e43
commit 5ec96fd46c
6 changed files with 419 additions and 941 deletions

View File

@@ -484,9 +484,6 @@ class TestAscendMLAImpl(TestBase):
chunk_ctx.chunk_seq_lens = [torch.tensor([8])]
chunk_ctx.chunk_seq_lens_npu = [torch.tensor([8])]
chunk_ctx.starts = [torch.tensor([0])]
chunk_ctx.max_chunk_num = 1
chunk_ctx.mask_for_non_zero_chunk = [True]
chunk_ctx.local_chunked_kv_lens = [[[[8]]]]
prefill_meta = MagicMock()
prefill_meta.chunked_context = chunk_ctx

View File

@@ -44,7 +44,6 @@ from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
extract_req_dcp_by_chunk_pcp,
filter_chunked_req_indices,
split_decodes_and_prefills)
from vllm_ascend.compilation.acl_graph import (get_graph_params,
@@ -169,10 +168,10 @@ class AscendMetadataForPrefill:
@dataclass
class ChunkedContextMetadata:
actual_chunk_seq_lengths: list[int]
mask_for_non_zero_chunk: Optional[list[bool]] = None
max_chunk_num: int = 0
local_chunked_kv_lens: Optional[list[Optional[list[Optional[list[
Optional[list[int]]]]]]]] = None
actual_seq_lengths_kv: list[int]
starts: torch.Tensor
chunked_req_mask: Optional[list[bool]] = None
local_context_lens_allranks: Optional[list[list[int]]] = None
cp_kv_recover_idx_for_chunk: Optional[list[int]] = None
kv_inverse_idx_for_chunk: Optional[list[int]] = None
@@ -286,25 +285,7 @@ class AscendAttentionMetadataBuilder:
AscendAttentionMetadataBuilder.reorder_batch_threshold = self.decode_threshold
scheduler_config = vllm_config.scheduler_config
self.block_size = vllm_config.cache_config.block_size
self.max_blocks = (vllm_config.model_config.max_model_len +
self.block_size - 1) // self.block_size
self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
if self.chunked_prefill_enabled:
self.chunked_prefill_workspace_size = min(
# Max sure there is enough for 8 full length request or at least
# 4 pages of cache per request
max(8 * self.model_config.max_model_len,
4 * scheduler_config.max_num_seqs * self.block_size),
128 * 1024)
assert self.chunked_prefill_workspace_size >= \
scheduler_config.max_num_seqs * self.block_size
self.chunked_prefill_workspace = torch.empty(
(self.chunked_prefill_workspace_size,
self.model_config.get_head_size()),
dtype=self.model_config.dtype,
device=device,
)
def reorder_batch(self, input_batch,
scheduler_output: "SchedulerOutput") -> bool:
@@ -385,6 +366,8 @@ class AscendAttentionMetadataBuilder:
prefill_metadata = None
decode_metadata = None
if common_long_seq_metadata is not None:
num_computed_tokens_of_pcp_dcp = common_long_seq_metadata.num_computed_tokens_of_pcp_dcp
assert num_computed_tokens_of_pcp_dcp is not None
chunked_context_metadata = None
if num_prefills > 0:
query_lens = query_lens[num_decode_tokens:]
@@ -394,18 +377,39 @@ class AscendAttentionMetadataBuilder:
pcp_size = get_prefill_context_model_parallel_world_size(
) if prefill_context_parallel_enable() else 1
if self.chunked_prefill_enabled and max_context_len_cpu > 0:
local_context_lens_allranks = torch.tensor(
num_computed_tokens_of_pcp_dcp
)[num_decodes:num_reqs].to(
self.device).to(dtype=torch.int32)
local_chunked_kv_lens_rank = local_context_lens_allranks[:,
self
.
pcp_rank,
self
.
dcp_rank]
actual_seq_lengths_kv = torch.cumsum(
local_chunked_kv_lens_rank, dim=0).tolist()
chunked_req_mask = self._get_chunked_req_mask(
local_context_lens_allranks)
local_chunk_starts = torch.zeros(
(len(local_context_lens_allranks)),
dtype=torch.int32,
device=self.device)
cp_kv_recover_idx_for_chunk = common_long_seq_metadata.cp_kv_recover_idx_for_chunk
kv_inverse_idx_for_chunk = torch.argsort(
cp_kv_recover_idx_for_chunk.to(torch.float32)
) if cp_kv_recover_idx_for_chunk is not None else None
chunked_context_metadata = \
AscendMetadataForPrefill.ChunkedContextMetadata(
actual_chunk_seq_lengths=torch.cumsum(query_lens * pcp_size, dim=0),
mask_for_non_zero_chunk=common_long_seq_metadata.mask_for_non_zero_chunk,
local_chunked_kv_lens=common_long_seq_metadata.local_chunked_kv_lens,
actual_seq_lengths_kv=actual_seq_lengths_kv,
chunked_req_mask=chunked_req_mask,
starts=local_chunk_starts,
local_context_lens_allranks=local_context_lens_allranks,
cp_kv_recover_idx_for_chunk=cp_kv_recover_idx_for_chunk,
kv_inverse_idx_for_chunk=kv_inverse_idx_for_chunk,
max_chunk_num=common_long_seq_metadata.max_chunk_num
kv_inverse_idx_for_chunk=kv_inverse_idx_for_chunk
)
attn_mask_seqlens = common_long_seq_metadata.attn_mask_seqlens
head_attn_nomask_seqlens = common_long_seq_metadata.head_attn_nomask_seqlens
@@ -445,8 +449,6 @@ class AscendAttentionMetadataBuilder:
actual_seq_lengths_q=torch.cumsum(query_lens, dim=0))
if num_decodes > 0:
num_computed_tokens_of_pcp_dcp = common_long_seq_metadata.num_computed_tokens_of_pcp_dcp
assert num_computed_tokens_of_pcp_dcp is not None
num_computed_tokens_array = np.array(
num_computed_tokens_of_pcp_dcp)
num_computed_tokens_array = num_computed_tokens_array[:
@@ -483,6 +485,19 @@ class AscendAttentionMetadataBuilder:
decode_meta=decode_metadata)
return attn_metadata
def _get_chunked_req_mask(self, local_context_lens_allranks) -> List[bool]:
"""
given 4-d list [req][pcp][dcp], return:
1. if each req has any chunk (list[bool])
"""
assert local_context_lens_allranks is not None
if len(local_context_lens_allranks) == 0:
return []
chunked_req_mask = [(req.sum() > 0).item()
for req in local_context_lens_allranks
if req is not None]
return chunked_req_mask
def build_for_graph_capture(
self,
common_attn_metadata: AscendCommonAttentionMetadata,
@@ -1205,11 +1220,11 @@ class AscendAttentionBackendImpl(AttentionImpl):
self.pcp_rank * num_tokens:(self.pcp_rank + 1) * num_tokens, :, :]
attn_lse_full_chunk = attn_lse_full_chunk[
self.pcp_rank * num_tokens:(self.pcp_rank + 1) * num_tokens, :, :]
assert attn_output_full_chunk.shape == current_attn_output_prefill.shape and attn_lse_full_chunk.shape == current_attn_lse_prefill.shape
seq_len = attn_metadata.query_lens.detach().clone()
filtered_indices = filter_chunked_req_indices(
seq_len,
attn_metadata.prefill.chunked_context.mask_for_non_zero_chunk)
seq_len, attn_metadata.prefill.chunked_context.chunked_req_mask)
attn_output_prefill_filtered = current_attn_output_prefill[
filtered_indices, :, :]
@@ -1221,18 +1236,23 @@ class AscendAttentionBackendImpl(AttentionImpl):
attn_output_filtered = self._npu_attn_out_lse_update(
attn_lse_prefill_filtered, attn_lse_full_chunk,
attn_output_prefill_filtered, attn_output_full_chunk)
current_attn_output_prefill[
filtered_indices, :, :] = attn_output_filtered.to(
current_attn_output_prefill.dtype)
def _prefill_query_all_gather(self, attn_metadata, prefill_query):
prefill_query_all = get_pcp_group().all_gather(prefill_query.contiguous(),
0) \
if self.pcp_size > 1 else prefill_query
prefill_query_all = torch.index_select(prefill_query_all,
if self.dcp_size > 1:
prefill_query = get_dcp_group().all_gather(prefill_query, 1)
if self.pcp_size > 1:
prefill_query = get_pcp_group().all_gather(prefill_query, 0)
prefill_query_all = torch.index_select(prefill_query,
0,
attn_metadata.prefill.chunked_context.cp_kv_recover_idx_for_chunk) \
if self.pcp_size > 1 else prefill_query_all
if self.pcp_size > 1 else prefill_query
return prefill_query_all
def _compute_prefill_context(self, query: torch.Tensor,
@@ -1243,217 +1263,132 @@ class AscendAttentionBackendImpl(AttentionImpl):
assert attn_metadata.prefill is not None
assert attn_metadata.prefill.chunked_context is not None
prefill_metadata = attn_metadata.prefill
local_chunked_kv_lens = attn_metadata.prefill.chunked_context.local_chunked_kv_lens
mask_for_non_zero_chunk = prefill_metadata.chunked_context.mask_for_non_zero_chunk
max_chunk_num = prefill_metadata.chunked_context.max_chunk_num
local_chunked_kv_lens = prefill_metadata.chunked_context.local_context_lens_allranks
assert local_chunked_kv_lens is not None
assert local_chunked_kv_lens is not None and mask_for_non_zero_chunk is not None and max_chunk_num > 0
local_chunked_kv_lens_rank = local_chunked_kv_lens[:, self.pcp_rank,
self.dcp_rank]
iters = max_chunk_num
# Keep the causal mask; do not override to all-ones. [req_id][chunk_id][cp-rank][dcp_rank]
context_starts_rank = None
prefix_output_list = []
prefix_lse_list = []
for i in range(iters):
key, value, seq_lens_current_chunk_rank = self._load_kv_for_chunk(
attn_metadata, kv_cache, context_starts_rank, i,
local_chunked_kv_lens, prefill_metadata, query)
# 2. Attention computation
if seq_lens_current_chunk_rank is None or torch.all(
seq_lens_current_chunk_rank == 0).item():
prefix_output = torch.full(
(query.size(0), self.num_heads, self.head_size),
fill_value=0,
dtype=query.dtype,
device=query.device)
prefix_lse = torch.full((query.size(0), self.num_heads, 1),
fill_value=0,
dtype=torch.float32,
device=query.device)
else:
actual_seq_lengths_kv = torch.cumsum(
seq_lens_current_chunk_rank, dim=0).tolist()
prefix_output, prefix_lse = torch.ops.npu.npu_fused_infer_attention_score(
query,
key,
value,
num_heads=self.num_heads,
num_key_value_heads=self.num_kv_heads,
input_layout="TND", #
atten_mask=None,
scale=self.scale,
sparse_mode=0,
antiquant_mode=0,
antiquant_scale=None,
softmax_lse_flag=True,
actual_seq_lengths_kv=actual_seq_lengths_kv,
actual_seq_lengths=attn_metadata.prefill.chunked_context.
actual_chunk_seq_lengths)
prefix_output_list.append(prefix_output)
prefix_lse_list.append(prefix_lse)
# 3. update attn-out & lse
prefix_output, prefix_lse = self._update_attn_out_lse_in_chunks(
prefix_output_list, prefix_lse_list)
self._update_attn_out_lse_in_pcp(attn_metadata, prefix_output,
prefix_lse)
return prefix_output, prefix_lse
def _update_attn_out_lse_in_chunks(self, prefix_output_list,
prefix_lse_list):
# update output and lse
if len(prefix_output_list) > 1:
prefix_output, prefix_lse = self._update_out_and_lse(
torch.stack(prefix_output_list, dim=0),
torch.stack(prefix_lse_list, dim=0))
key, value = self._load_kv_for_chunk(attn_metadata, kv_cache,
local_chunked_kv_lens_rank, query)
if self.dcp_size > 1:
num_heads = self.num_heads * self.dcp_size
else:
prefix_output = prefix_output_list[0]
prefix_lse = prefix_lse_list[0]
num_heads = self.num_heads
prefix_chunk_output = torch.full(
(query.size(0), num_heads, self.head_size),
fill_value=0,
dtype=query.dtype,
device=query.device)
prefix_chunk_lse = torch.full((query.size(0), num_heads, 1),
fill_value=-torch.inf,
dtype=torch.float32,
device=query.device)
if not torch.all(local_chunked_kv_lens_rank == 0).item():
prefix_chunk_output, prefix_chunk_lse = torch.ops.npu.npu_fused_infer_attention_score(
query,
key,
value,
num_heads=num_heads,
num_key_value_heads=self.num_kv_heads,
input_layout="TND",
atten_mask=None,
scale=self.scale,
sparse_mode=0,
antiquant_mode=0,
antiquant_scale=None,
softmax_lse_flag=True,
actual_seq_lengths_kv=prefill_metadata.chunked_context.
actual_seq_lengths_kv,
actual_seq_lengths=attn_metadata.prefill.chunked_context.
actual_chunk_seq_lengths)
prefix_output, prefix_lse = self._update_chunk_attn_out_lse(
prefix_chunk_output, prefix_chunk_lse)
return prefix_output, prefix_lse
def _update_attn_out_lse_in_pcp(self, attn_metadata, prefix_output,
prefix_lse):
def _update_chunk_attn_out_lse(self, prefix_chunk_output,
prefix_chunk_lse):
# CP dimension all_gather and fusion
if self.pcp_size > 1:
# filter non-zero chunk part of prefix_output
current_seq_lens = attn_metadata.query_lens.detach().clone()
current_seq_lens.mul_(self.pcp_size) # q_full
current_seq_lens_cpu = current_seq_lens.cpu()
filtered_indices = filter_chunked_req_indices(
current_seq_lens_cpu,
attn_metadata.prefill.chunked_context.mask_for_non_zero_chunk)
prefix_output_filtered = prefix_output[filtered_indices, :, :]
prefix_lse_filtered = prefix_lse[filtered_indices, :, :]
chunk_attn_out_lse = torch.cat([prefix_chunk_output, prefix_chunk_lse],
dim=-1)
out_lse_local = torch.cat(
[prefix_output_filtered, prefix_lse_filtered], dim=-1)
if self.dcp_size > 1:
chunk_attn_out_lse = chunk_attn_out_lse.permute([1, 2,
0]).contiguous()
attn_out_lse_all2all = torch.empty_like(chunk_attn_out_lse)
dist.all_to_all_single(attn_out_lse_all2all,
chunk_attn_out_lse,
group=self.dcp_group)
attn_out_lse_all2all = attn_out_lse_all2all.permute([2, 0, 1])
if self.pcp_size > 1:
chunk_attn_out_lse = attn_out_lse_all2all.contiguous()
attn_out_lse_list = list(
torch.chunk(attn_out_lse_all2all, self.dcp_size, dim=1))
if self.pcp_size > 1:
attn_out_lse_list = [
torch.empty_like(out_lse_local) for _ in range(self.pcp_size)
torch.empty_like(chunk_attn_out_lse)
for _ in range(self.pcp_size)
]
dist.all_gather(attn_out_lse_list,
out_lse_local,
chunk_attn_out_lse,
group=self.pcp_group)
attn_out_lse_allgather = torch.stack(
attn_out_lse_list,
dim=0) # [pcp, batch_size, num_heads, head_size+1]
attn_out_allgather, attn_lse_allgather = torch.split(
attn_out_lse_allgather, [self.head_size, 1], dim=-1)
prefix_output_filtered, prefix_lse_filtered = self._update_out_and_lse(
attn_out_allgather, attn_lse_allgather)
if self.dcp_size > 1 and self.pcp_size > 1:
attn_out_lse_list_pcp_dcp = []
for s in attn_out_lse_list:
attn_out_lse_list_split = list(
torch.chunk(s, self.dcp_size, dim=1))
attn_out_lse_list_pcp_dcp += attn_out_lse_list_split
attn_out_lse_list = attn_out_lse_list_pcp_dcp
prefix_output[filtered_indices, :, :] = prefix_output_filtered.to(
prefix_output.dtype)
prefix_lse[filtered_indices, :, :] = prefix_lse_filtered.to(
prefix_lse.dtype)
attn_out_lse_allgather = torch.stack(
attn_out_lse_list,
dim=0) # [pcp, batch_size, num_heads, head_size+1]
attn_out_allgather, attn_lse_allgather = torch.split(
attn_out_lse_allgather, [self.head_size, 1], dim=-1)
def _load_kv_for_chunk(self, attn_metadata, kv_cache, context_starts_rank,
i, local_chunked_kv_lens, prefill_metadata, query):
prefix_output, prefix_lse = self._update_out_and_lse(
attn_out_allgather, attn_lse_allgather)
return prefix_output, prefix_lse
def _load_kv_for_chunk(self, attn_metadata, kv_cache,
local_chunked_kv_lens_rank, query):
cache_key = kv_cache[0]
cache_value = kv_cache[1]
num_heads = cache_key.size(2)
head_size = kv_cache[0].size(-1)
# 1. Load current query's history key-value
seq_lens_current_chunk = attn_metadata.query_lens.detach().clone()
num_requests = len(seq_lens_current_chunk)
# Before dealing with a new chunk, set to zero, and accumulate the start positions as chunk prefill step increases
context_starts_rank = torch.zeros(
num_requests, dtype=torch.int32, device=query.device
) if context_starts_rank is None else context_starts_rank
# Calculate tokens each rank should process per request
seq_lens_current_chunk_rank = torch.zeros_like(seq_lens_current_chunk,
dtype=torch.int32,
device=query.device)
total_toks = 0
for req_idx in range(num_requests):
if i >= len(local_chunked_kv_lens[req_idx]):
continue
n_computed_acc = local_chunked_kv_lens[req_idx][i]
total_toks += n_computed_acc[self.pcp_rank][self.dcp_rank]
seq_lens_current_chunk_rank[req_idx] = n_computed_acc[
self.pcp_rank][self.dcp_rank]
if total_toks > 0:
key = torch.empty(total_toks,
num_heads,
head_size,
dtype=query.dtype,
device=query.device)
value = torch.empty(total_toks,
num_heads,
head_size,
dtype=query.dtype,
device=query.device)
total_toks = local_chunked_kv_lens_rank.sum()
key = torch.empty(total_toks,
num_heads,
head_size,
dtype=query.dtype,
device=query.device)
value = torch.empty(total_toks,
num_heads,
head_size,
dtype=query.dtype,
device=query.device)
if total_toks > 0:
torch_npu.atb.npu_paged_cache_load(
cache_key,
cache_value,
attn_metadata.prefill.block_tables,
seq_lens_current_chunk_rank.to(query.device),
seq_starts=
context_starts_rank, # slot offsets of current chunk in current iteration
local_chunked_kv_lens_rank,
seq_starts=attn_metadata.prefill.chunked_context.
starts, # slot offsets of current chunk in current iteration
key=key,
value=value,
)
else:
# If current rank has no tokens to process, create empty tensors
key = torch.empty(0,
self.num_heads,
self.head_size,
dtype=query.dtype,
device=query.device)
value = torch.empty(0,
self.num_heads,
self.head_size,
dtype=query.dtype,
device=query.device)
seq_lens_current_chunk_rank = torch.zeros(
(len(seq_lens_current_chunk), ),
dtype=torch.int32,
device=query.device)
for req_idx in range(num_requests):
# Before dealing with a new chunk, set to zero, and accumulate the start positions as chunk prefill step increases
if i >= len(local_chunked_kv_lens[req_idx]):
continue
context_starts_rank[req_idx] += local_chunked_kv_lens[req_idx][i][
self.pcp_rank][self.dcp_rank]
if self.dcp_size > 1:
req_dcp_sizes = extract_req_dcp_by_chunk_pcp(
local_chunked_kv_lens, i, self.dcp_size, self.pcp_rank)
assert len(req_dcp_sizes) == num_requests and all(
len(dcp_arr) == self.dcp_size for dcp_arr in req_dcp_sizes)
total_toks = np.sum(np.array(req_dcp_sizes))
kv_local = torch.cat([key, value], dim=-1)
head_dim = kv_local.size(-1)
kv_full = torch.empty((total_toks, num_heads, head_dim),
device=query.device,
dtype=query.dtype)
kv_full_list = [None for _ in range(self.dcp_size)]
dist.all_gather_object(kv_full_list,
kv_local,
group=self.dcp_group)
kv_full_list = [
kv for kv in kv_full_list if kv is not None and kv.numel() > 0
]
if len(kv_full_list) > 0:
kv_full = torch.cat(kv_full_list, dim=0)
key, value = kv_full.split([head_size, head_size], dim=-1)
if total_toks == 0:
return key, value, None
seq_lens_current_chunk_rank = torch.tensor(
np.sum(np.array(req_dcp_sizes), axis=1),
dtype=torch.int32,
device=query.device) # [reqs]
return key, value, seq_lens_current_chunk_rank
return key, value
def forward(
self,

View File

@@ -5,7 +5,6 @@ from typing import (TYPE_CHECKING, ClassVar, List, NamedTuple, Optional, Tuple,
import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
import torch_npu
from torch import nn
from vllm.attention.backends.abstract import (AttentionBackend,
@@ -35,14 +34,11 @@ from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm_ascend import envs
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.attention.attention_v1 import AscendAttentionState
# isort: off
from vllm_ascend.attention.utils import (
AscendCommonAttentionMetadata, extract_req_dcp_by_chunk_pcp,
filter_chunked_req_indices, maybe_save_kv_layer_to_connector,
split_decodes_and_prefills, trans_rope_weight, transdata,
wait_for_kv_layer_from_connector)
# isort: on
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
maybe_save_kv_layer_to_connector,
split_decodes_and_prefills,
trans_rope_weight, transdata,
wait_for_kv_layer_from_connector)
from vllm_ascend.compilation.acl_graph import (get_graph_params,
update_graph_params_workspaces)
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
@@ -122,10 +118,13 @@ class AscendMLAPrefillMetadata:
workspace: torch.Tensor
chunk_seq_lens: torch.Tensor
chunk_seq_lens_npu: torch.Tensor
mask_for_non_zero_chunk: Optional[list[bool]] = None
max_chunk_num: int = 0
local_chunked_kv_lens: Optional[list[Optional[list[Optional[list[
Optional[list[int]]]]]]]] = None
# for mla DCP & PCP
padded_chunk_seq_lens_npu: torch.Tensor = None
padded_local_chunk_seq_lens: Optional[list[list[int]]] = None
local_context_lens_allranks: Optional[list[list[int]]] = None
padded_local_cu_seq_lens: torch.Tensor = None
cu_seq_lens_lst: Optional[list[list[int]]] = None
chunk_size: Optional[int] = None
attn_mask: torch.Tensor
query_lens: torch.Tensor
@@ -140,7 +139,6 @@ class AscendMLAPrefillMetadata:
sin: torch.Tensor = None
cos: torch.Tensor = None
pcp_metadata: Optional[AscendPCPMetadata] = None
cp_kv_recover_idx_for_chunk: Optional[list[int]] = None
@dataclass
@@ -285,6 +283,9 @@ class AscendMLAMetadataBuilder:
self.dcp_size = get_decode_context_model_parallel_world_size()
self.dcp_rank = get_decode_context_model_parallel_rank(
) if self.dcp_size > 1 else 0
self.cp_local_block_size = vllm_config.parallel_config.cp_kv_cache_interleave_size if prefill_context_parallel_enable(
) else 1
self.cp_virtual_block_size = self.cp_local_block_size * self.dcp_size * self.pcp_size
decode_max_num_seqs = getattr(scheduler_config, 'decode_max_num_seqs',
0)
max_num_seqs = max(scheduler_config.max_num_seqs, decode_max_num_seqs)
@@ -292,16 +293,6 @@ class AscendMLAMetadataBuilder:
self.decode_threshold,
dtype=torch.uint8,
device=device)
self.seq_mask_pcp_buf = torch.empty(max_num_seqs *
self.decode_threshold,
self.pcp_size,
dtype=torch.uint8,
device=device)
self.seq_mask_dcp_buf = torch.empty(max_num_seqs *
self.decode_threshold,
self.dcp_size,
dtype=torch.uint8,
device=device)
def reorder_batch(self, input_batch: "InputBatch",
scheduler_output: "SchedulerOutput") -> bool:
@@ -366,10 +357,6 @@ class AscendMLAMetadataBuilder:
num_actual_tokens_pcp_padded = long_seq_metadata.num_actual_tokens_pcp_padded if long_seq_metadata else None
num_computed_tokens_of_pcp_dcp = long_seq_metadata.num_computed_tokens_of_pcp_dcp if long_seq_metadata else None
cp_kv_recover_idx_for_chunk = long_seq_metadata.cp_kv_recover_idx_for_chunk if long_seq_metadata else None
local_chunked_kv_lens = long_seq_metadata.local_chunked_kv_lens if long_seq_metadata else None
mask_for_non_zero_chunk = long_seq_metadata.mask_for_non_zero_chunk if long_seq_metadata else None
max_chunk_num = long_seq_metadata.max_chunk_num if long_seq_metadata else 0
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
split_decodes_and_prefills(common_attn_metadata, decode_threshold=self.decode_threshold)
@@ -468,19 +455,75 @@ class AscendMLAMetadataBuilder:
dim=1,
out=cu_seq_lens_cpu[:, 1:],
dtype=torch.int32)
chunked_context_metadata = \
if self.dcp_size * self.pcp_size > 1:
if num_computed_tokens_of_pcp_dcp is not None:
local_context_lens_allranks = torch.tensor(
num_computed_tokens_of_pcp_dcp[reqs_start:num_reqs]
).reshape(-1, self.dcp_size * self.pcp_size)
# Note(qcs): The max local context lengths
# padded to `cp_local_block_size`.
padded_local_context_lens_cpu = (cdiv(
context_lens_cpu,
self.cp_virtual_block_size,
) * self.cp_local_block_size)
padded_local_max_context_chunk_across_ranks = (cdiv(
max_context_chunk,
self.cp_virtual_block_size,
) * self.cp_local_block_size)
local_chunk_starts = (
torch.arange(num_chunks,
dtype=torch.int32).unsqueeze(1).expand(
-1, num_prefills) *
padded_local_max_context_chunk_across_ranks)
local_chunk_ends = torch.min(
padded_local_context_lens_cpu.unsqueeze(0),
local_chunk_starts +
padded_local_max_context_chunk_across_ranks,
)
padded_local_chunk_seq_lens = (local_chunk_ends -
local_chunk_starts).clamp(
min=0)
padded_local_cu_chunk_seq_lens_cpu = torch.zeros(
num_chunks,
num_prefills + 1,
dtype=torch.int32,
pin_memory=True)
torch.cumsum(
padded_local_chunk_seq_lens,
dim=1,
out=padded_local_cu_chunk_seq_lens_cpu[:, 1:],
dtype=torch.int32,
)
chunked_context_metadata = \
AscendMLAPrefillMetadata.ChunkedContextMetadata(
cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True),
starts=chunk_starts.to(device, non_blocking=True),
seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
chunk_seq_lens=chunk_seq_lens,
chunk_seq_lens_npu=chunk_seq_lens.npu(),
workspace=self.chunked_prefill_workspace,
local_chunked_kv_lens=local_chunked_kv_lens,
mask_for_non_zero_chunk=mask_for_non_zero_chunk,
max_chunk_num=max_chunk_num,
)
cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True),
starts=local_chunk_starts.to(device, non_blocking=True),
seq_tot=padded_local_chunk_seq_lens.sum(dim=1).tolist(),
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
chunk_seq_lens=chunk_seq_lens,
chunk_seq_lens_npu=chunk_seq_lens.npu(),
workspace=self.chunked_prefill_workspace,
padded_chunk_seq_lens_npu=padded_local_chunk_seq_lens.npu(),
padded_local_chunk_seq_lens=padded_local_chunk_seq_lens.tolist(),
local_context_lens_allranks=local_context_lens_allranks.tolist(),
padded_local_cu_seq_lens=padded_local_cu_chunk_seq_lens_cpu.to(
device, non_blocking=True
),
cu_seq_lens_lst=cu_seq_lens_cpu.tolist(),
chunk_size=padded_local_max_context_chunk_across_ranks,
)
else:
chunked_context_metadata = \
AscendMLAPrefillMetadata.ChunkedContextMetadata(
cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True),
starts=chunk_starts.to(device, non_blocking=True),
seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
chunk_seq_lens=chunk_seq_lens,
chunk_seq_lens_npu=chunk_seq_lens.npu(),
workspace=self.chunked_prefill_workspace,
)
prefill_input_positions = input_positions[tokens_start:]
cos = self.cos_cache[
prefill_input_positions].unsqueeze( # type: ignore
@@ -502,7 +545,7 @@ class AscendMLAMetadataBuilder:
sin=sin,
cos=cos,
pcp_metadata=pcp_metadata,
cp_kv_recover_idx_for_chunk=cp_kv_recover_idx_for_chunk)
)
decode_metadata = None
if num_decodes > 0:
@@ -516,7 +559,7 @@ class AscendMLAMetadataBuilder:
block_table = block_table[:num_decodes, ...]
# For pcp + spec decode, we flatten seq_lens and block_table
# to avoid irregular spec_attn_mask shape
if self.pcp_size > 1:
if self.pcp_size > 1 and self.decode_threshold > 1:
block_table = block_table.repeat_interleave(
self.decode_threshold, dim=0)
seq_lens_list = seq_lens.tolist()
@@ -921,26 +964,8 @@ class AscendMLAImpl(MLAAttentionImpl):
prefill_metadata = attn_metadata.prefill
if prefill_metadata is None or prefill_metadata.chunked_context is None:
return prefix_output, prefix_lse
local_chunked_kv_lens = prefill_metadata.chunked_context.local_chunked_kv_lens
mask_for_non_zero_chunk = prefill_metadata.chunked_context.mask_for_non_zero_chunk
max_chunk_num = prefill_metadata.chunked_context.max_chunk_num
if self.pcp_size * self.dcp_size > 1:
assert local_chunked_kv_lens is not None and mask_for_non_zero_chunk is not None and max_chunk_num > 0
if self.pcp_size > 1:
prefix_output = torch.zeros(q_nope.shape[0],
self.num_heads,
self.v_head_dim,
dtype=q_nope.dtype,
device=q_nope.device)
prefix_lse = torch.zeros(self.num_heads,
q_pe.shape[0],
dtype=torch.float32,
device=q_pe.device)
iters = len(prefill_metadata.chunked_context.seq_tot)
if self.pcp_size * self.dcp_size > 1:
iters = max_chunk_num
current_seq_len = torch.tensor(prefill_metadata.query_lens,
dtype=torch.int32)
@@ -948,305 +973,97 @@ class AscendMLAImpl(MLAAttentionImpl):
cache_k_pe = kv_c_and_k_pe_cache[1]
num_heads = cache_k_pe.size(2)
latent_kv_dim = kv_c_and_k_pe_cache[0].size(-1)
# token -> request mapping for building per-token masks when CP>1
seq_len1 = torch.tensor(prefill_metadata.query_lens,
dtype=torch.int32,
device=q_nope.device)
seq_len1.mul_(
self.pcp_size) # q_full: already padded, divisible by cp_size
# Select mask: prefer CP prefill mask from metadata; fallback to cached prefill_mask; create if needed.
mask_local = None
if attn_metadata is not None and attn_metadata.prefill is not None and \
attn_metadata.prefill.pcp_metadata is not None and attn_metadata.prefill.pcp_metadata.pcp_prefill_mask is not None:
mask_local = attn_metadata.prefill.pcp_metadata.pcp_prefill_mask
else:
mask_local = self.prefill_mask
if mask_local is None:
mask_local = torch.triu(
torch.ones(512,
512,
device=q_nope.device,
dtype=q_nope.dtype), 1)
self.prefill_mask = mask_local
# Keep the causal mask; do not override to all-ones.
context_starts_rank = None
for i in range(iters):
if self.pcp_size * self.dcp_size > 1:
## DCP mode: each rank processes its own (cp,dcp) historical context slice per request dimension
num_requests = len(seq_len1)
assert num_requests == len(local_chunked_kv_lens)
# Before dealing with a new chunk, set to zero, and accumulate the start positions as chunk prefill step increases
context_starts_rank = torch.zeros(
num_requests, dtype=torch.int32, device=q_nope.device
) if context_starts_rank is None else context_starts_rank
toks = prefill_metadata.chunked_context.seq_tot[i]
# chunk_seq_lens will be padded when pcp&dcp
context_seq_len = prefill_metadata.chunked_context.chunk_seq_lens[
i]
context_seq_len_npu = prefill_metadata.chunked_context.chunk_seq_lens_npu[
i]
seq_len = torch.stack([current_seq_len, context_seq_len])
kv_c_normed = torch.empty(toks,
num_heads,
latent_kv_dim,
dtype=q_nope.dtype,
device=q_nope.device)
k_pe = torch.empty(toks,
num_heads,
rope_dim,
dtype=q_nope.dtype,
device=q_nope.device)
## Calculate tokens each rank should process per request
seq_len2_rank = torch.zeros_like(seq_len1, dtype=torch.int32)
total_toks = 0
for req_idx in range(num_requests):
if i >= len(local_chunked_kv_lens[req_idx]):
continue
n_computed_acc = local_chunked_kv_lens[req_idx][i]
total_toks += n_computed_acc[self.pcp_rank][self.dcp_rank]
seq_len2_rank[req_idx] = n_computed_acc[self.pcp_rank][
self.dcp_rank]
if total_toks > 0:
kv_c_normed = torch.empty(total_toks,
num_heads,
latent_kv_dim,
dtype=q_nope.dtype,
device=q_nope.device)
k_pe = torch.empty(total_toks,
num_heads,
rope_dim,
dtype=q_nope.dtype,
device=q_nope.device)
torch_npu.atb.npu_paged_cache_load(
cache_kv_c,
cache_k_pe,
prefill_metadata.block_table,
seq_len2_rank.to(q_nope.device),
seq_starts=
context_starts_rank, # slot offsets of current chunk in current iteration
key=kv_c_normed,
value=k_pe,
)
seq_len2 = seq_len2_rank.to(q_nope.device)
else:
# If current rank has no tokens to process, create empty tensors
kv_c_normed = torch.empty(0,
num_heads,
latent_kv_dim,
dtype=q_nope.dtype,
device=q_nope.device)
k_pe = torch.empty(0,
num_heads,
rope_dim,
dtype=q_nope.dtype,
device=q_nope.device)
seq_len2 = torch.zeros((len(seq_len1), ),
dtype=torch.int32,
device=q_nope.device)
seq_len = torch.stack([seq_len1.cpu(), seq_len2.cpu()])
for req_idx in range(num_requests):
# Before dealing with a new chunk, set to zero, and accumulate the start positions as chunk prefill step increases
if i >= len(local_chunked_kv_lens[req_idx]):
continue
context_starts_rank[req_idx] += local_chunked_kv_lens[
req_idx][i][self.pcp_rank][self.dcp_rank]
else:
# Original logic: ChunkPrefill-only mode
toks = prefill_metadata.chunked_context.seq_tot[i]
context_seq_len = prefill_metadata.chunked_context.chunk_seq_lens[
if self.dcp_size * self.pcp_size > 1:
context_seq_len_npu = prefill_metadata.chunked_context.padded_chunk_seq_lens_npu[
i]
context_seq_len_npu = prefill_metadata.chunked_context.chunk_seq_lens_npu[
i]
seq_len = torch.stack([current_seq_len, context_seq_len])
kv_c_normed = torch.empty(toks,
num_heads,
latent_kv_dim,
dtype=q_nope.dtype,
device=q_nope.device)
k_pe = torch.empty(toks,
num_heads,
rope_dim,
dtype=q_nope.dtype,
device=q_nope.device)
torch_npu.atb.npu_paged_cache_load(
cache_kv_c,
cache_k_pe,
prefill_metadata.block_table,
context_seq_len_npu,
seq_starts=prefill_metadata.chunked_context.starts[i],
key=kv_c_normed,
value=k_pe,
)
torch_npu.atb.npu_paged_cache_load(
cache_kv_c,
cache_k_pe,
prefill_metadata.block_table,
context_seq_len_npu,
seq_starts=prefill_metadata.chunked_context.starts[i],
key=kv_c_normed,
value=k_pe,
cache_kv_c_k_pe = torch.cat([kv_c_normed, k_pe], dim=-1)
if self.dcp_size > 1:
cache_kv_c_k_pe = get_dcp_group().all_gather(
cache_kv_c_k_pe, 0)
if self.pcp_size > 1:
cache_kv_c_k_pe = get_pcp_group().all_gather(
cache_kv_c_k_pe, 0)
if self.dcp_size * self.pcp_size > 1:
allgatered_kv_c_normed, allgatered_k_pe = cache_kv_c_k_pe.split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
kv_c_normed, k_pe = self._reorg_kvcache(
allgatered_kv_c_normed,
allgatered_k_pe,
padded_local_chunk_seq_lens_lst=prefill_metadata.
chunked_context.padded_local_chunk_seq_lens[i],
local_context_lens_allranks=prefill_metadata.
chunked_context.local_context_lens_allranks,
sum_seq_len=prefill_metadata.chunked_context.
cu_seq_lens_lst[i][-1],
max_seq_len=prefill_metadata.chunked_context.
max_seq_lens[i],
chunk_size=prefill_metadata.chunked_context.chunk_size,
chunk_idx=i,
toks=toks,
)
kv_c_normed = kv_c_normed.squeeze()
if self.dcp_size > 1:
# DCP mode: first all_gather within DCP group, let each rank in CP group share complete sequence blocks
# Step 1: DCP all_gather latent
kv_c_k_pe_local = torch.cat(
[kv_c_normed, k_pe.squeeze()],
dim=-1) # [local_toks, latent_dim + rope_dim]
# Step 2: use all_gather_into_tensor_uneven (gather + cat)
req_dcp_sizes = extract_req_dcp_by_chunk_pcp(
local_chunked_kv_lens, i, self.dcp_size, self.pcp_rank
) # need to know num tokens of each rank in dcp group before all_gather # [reqs, dcp]
assert len(req_dcp_sizes) == num_requests and all(
len(dcp_arr) == self.dcp_size for dcp_arr in req_dcp_sizes)
total_toks = np.sum(np.array(req_dcp_sizes))
latent_rope_dim = kv_c_k_pe_local.size(-1)
kv_c_k_pe_full = torch.empty((total_toks, latent_rope_dim),
device=kv_c_k_pe_local.device,
dtype=kv_c_k_pe_local.dtype)
kv_c_k_pe_full_list = [None for _ in range(self.dcp_size)]
dist.all_gather_object(kv_c_k_pe_full_list,
kv_c_k_pe_local,
group=self.dcp_group)
kv_c_k_pe_full_list = [
kv_c_k_pe for kv_c_k_pe in kv_c_k_pe_full_list
if kv_c_k_pe is not None and kv_c_k_pe.numel() > 0
]
if len(kv_c_k_pe_full_list) > 0:
kv_c_k_pe_full = torch.cat(kv_c_k_pe_full_list, dim=0)
if len(kv_c_k_pe_full.shape) == 1:
assert total_toks == 1
kv_c_k_pe_full = kv_c_k_pe_full.unsqueeze(0)
assert kv_c_k_pe_full.shape[
0] == total_toks and kv_c_k_pe_full.shape[
1] == latent_rope_dim
kv_c_normed_full, k_pe_full = torch.split(
kv_c_k_pe_full, [latent_kv_dim, rope_dim], dim=-1)
# Step 3: process complete sequence with TP projection to get current rank's head slice
# Case that no kv_cache has been stored on this CP rank(after dcp all_gather), no need to do following computation.
if total_toks == 0:
continue
kv_nope = self.kv_b_proj(kv_c_normed_full)[0].view(
-1, self.num_heads,
self.qk_nope_head_dim + self.v_head_dim)
k_nope, v = kv_nope.split(
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
k_pe = k_pe_full.unsqueeze(1).expand((*k_nope.shape[:-1], -1))
seq_len2 = torch.tensor(np.sum(np.array(req_dcp_sizes),
axis=1),
dtype=torch.int32,
device=q_nope.device) # [reqs]
seq_len = torch.stack([seq_len1.cpu(), seq_len2.cpu()])
else:
# Non-DCP mode: use TP-split projection
kv_nope = self.kv_b_proj(kv_c_normed)[0].view(
-1, self.num_heads,
self.qk_nope_head_dim + self.v_head_dim)
k_nope, v = kv_nope.split(
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
k_pe = k_pe.expand((*k_nope.shape[:-1], -1))
kv_nope = self.kv_b_proj(kv_c_normed)[0].view( \
-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
k_nope, v = kv_nope\
.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
k_pe = k_pe.expand((*k_nope.shape[:-1], -1))
if self.pcp_size > 1:
# Case that no kv_cache has been stored on this CP rank, no need to do following computation.
if torch.all(seq_len2 == 0).item():
continue
# PCP mode: first compute this rank's contribution to the chunk
if i == 0:
torch_npu.atb.npu_ring_mla(
q_nope=q_nope,
q_rope=q_pe,
k_nope=k_nope,
k_rope=k_pe,
value=v,
mask=mask_local,
seqlen=seq_len,
head_num=self.num_heads,
kv_head_num=self.num_heads,
pre_out=None,
prev_lse=None,
qk_scale=self.scale,
kernel_type="kernel_type_high_precision",
mask_type="no_mask",
input_layout="type_bsnd",
calc_type="calc_type_first_ring",
output=prefix_output,
softmax_lse=prefix_lse)
continue
torch_npu.atb.npu_ring_mla(
q_nope=q_nope,
q_rope=q_pe,
k_nope=k_nope,
k_rope=k_pe,
value=v,
mask=mask_local,
seqlen=seq_len,
head_num=self.num_heads,
kv_head_num=self.num_heads,
pre_out=prefix_output,
prev_lse=prefix_lse,
qk_scale=self.scale,
kernel_type="kernel_type_high_precision",
mask_type="no_mask",
input_layout="type_bsnd",
calc_type="calc_type_default",
output=prefix_output,
softmax_lse=prefix_lse)
mask = attn_metadata.prefill.pcp_metadata.pcp_prefill_mask
else:
assert not torch.all(context_seq_len == 0).item()
# compute this chunk block then update prefix tensors to keep shapes consistent
torch_npu.atb.npu_ring_mla(
q_nope=q_nope,
q_rope=q_pe,
k_nope=k_nope,
k_rope=k_pe,
value=v,
mask=mask_local,
seqlen=seq_len,
head_num=self.num_heads,
kv_head_num=self.num_heads,
pre_out=prefix_output,
prev_lse=prefix_lse,
qk_scale=self.scale,
kernel_type="kernel_type_high_precision",
mask_type="no_mask",
input_layout="type_bsnd",
calc_type="calc_type_default",
output=prefix_output,
softmax_lse=prefix_lse)
# CP dimension all_gather and fusion
if self.pcp_size > 1:
# filter non-zero chunk part of prefix_output
seq_len1_cpu = seq_len1.cpu()
filtered_indices = filter_chunked_req_indices(
seq_len1_cpu, mask_for_non_zero_chunk)
prefix_output_filtered = prefix_output[filtered_indices, :, :]
prefix_lse_filtered = prefix_lse[:, filtered_indices]
# normalize prefix LSE to [bs, heads, 1] for stable updates
prefix_lse_filtered_bt = prefix_lse_filtered.permute(
1, 0).unsqueeze(-1).contiguous(
) if prefix_lse_filtered is not None else None
out_lse_local = torch.cat(
[prefix_output_filtered, prefix_lse_filtered_bt], dim=-1)
out_lse_list = [
torch.empty_like(out_lse_local) for _ in range(self.pcp_size)
]
dist.all_gather(out_lse_list, out_lse_local, group=self.pcp_group)
prefix_output_filtered = None
prefix_lse_filtered_bt = None
for r in range(self.pcp_size):
out_lse_r = out_lse_list[r]
if torch.all(out_lse_r == 0).item():
continue
out_r, lse_r = torch.split(out_lse_r, [self.v_head_dim, 1],
dim=-1)
token_mask = torch.ones([out_r.size(0)],
dtype=torch.uint8,
device=out_r.device)
prefix_output_filtered, prefix_lse_filtered_bt = self._update_out_and_lse(
prefix_output_filtered, prefix_lse_filtered_bt, out_r,
lse_r, token_mask)
# convert lse back to [heads, bs]
assert prefix_output_filtered is not None and prefix_lse_filtered_bt is not None
prefix_lse_filtered = prefix_lse_filtered_bt.squeeze(-1).permute(
1, 0).contiguous()
prefix_output[filtered_indices, :, :] = prefix_output_filtered.to(
prefix_output.dtype)
prefix_lse[:, filtered_indices] = prefix_lse_filtered.to(
prefix_lse.dtype)
mask = self.prefill_mask
torch_npu.atb.npu_ring_mla(
q_nope=q_nope,
q_rope=q_pe,
k_nope=k_nope,
k_rope=k_pe,
value=v,
mask=mask,
seqlen=seq_len,
head_num=self.num_heads,
kv_head_num=self.num_heads,
pre_out=prefix_output,
prev_lse=prefix_lse,
qk_scale=self.scale,
kernel_type="kernel_type_high_precision",
mask_type="no_mask",
input_layout="type_bsnd",
calc_type="calc_type_default",
output=prefix_output,
softmax_lse=prefix_lse)
return prefix_output, prefix_lse
def _forward_prefill(
@@ -1814,8 +1631,7 @@ class AscendMLAImpl(MLAAttentionImpl):
head_attn_nomask_seqlens = attn_metadata.prefill.pcp_metadata.head_attn_nomask_seqlens
tail_attn_nomask_seqlens = attn_metadata.prefill.pcp_metadata.tail_attn_nomask_seqlens
mask = attn_metadata.prefill.pcp_metadata.pcp_prefill_mask
output_head, head_lse = self._attention_with_mask_and_nomask(
output_head, lse_head = self._attention_with_mask_and_nomask(
q_nope=torch.index_select(q_nope, 0, q_head_idx),
q_pe=torch.index_select(q_pe, 0, q_head_idx),
k_nope=k_nope,
@@ -1827,7 +1643,7 @@ class AscendMLAImpl(MLAAttentionImpl):
attn_nomask_seqlens=head_attn_nomask_seqlens,
mask=mask)
output_tail, tail_lse = self._attention_with_mask_and_nomask(
output_tail, lse_tail = self._attention_with_mask_and_nomask(
q_nope=torch.index_select(q_nope, 0, q_tail_idx),
q_pe=torch.index_select(q_pe, 0, q_tail_idx),
k_nope=k_nope,
@@ -1840,86 +1656,15 @@ class AscendMLAImpl(MLAAttentionImpl):
mask=mask)
q_full_idx = attn_metadata.prefill.pcp_metadata.q_full_idx
output = torch.index_select(
attn_output = torch.index_select(
torch.cat([output_head, output_tail], dim=0), 0, q_full_idx)
attn_lse = torch.index_select(torch.cat([lse_head, lse_tail], dim=1),
1, q_full_idx)
# Synchronize and reorder LSE for subsequent chunked context accumulation
attn_lse = torch.cat([head_lse, tail_lse], dim=1)
attn_lse = attn_lse[:, q_full_idx]
output, _ = self._compute_prefill_context( \
q_nope, q_pe, kv_c_and_k_pe_cache, self.qk_rope_head_dim, attn_metadata, attn_output, attn_lse)
# Post-processing: keep [tokens, H, V] shape and perform chunked context accumulation if needed
if attn_metadata.prefill is not None and \
attn_metadata.prefill.chunked_context is not None:
# q all_gather
q_nope_full = get_pcp_group().all_gather(q_nope.contiguous(), 0)
q_pe_full = get_pcp_group().all_gather(q_pe.contiguous(), 0)
q_nope_full = torch.index_select(
q_nope_full, 0,
attn_metadata.prefill.cp_kv_recover_idx_for_chunk)
q_pe_full = torch.index_select(
q_pe_full, 0,
attn_metadata.prefill.cp_kv_recover_idx_for_chunk)
attn_output_pre = output.view(num_tokens, self.num_heads,
self.v_head_dim)
attn_output_pre_full, attn_lse_full = self._compute_prefill_context(
q_nope_full,
q_pe_full,
kv_c_and_k_pe_cache,
self.qk_rope_head_dim,
attn_metadata,
None,
None,
)
# reorder back && extract output + lse result of each cp rank
inverse_idx = torch.argsort(
attn_metadata.prefill.cp_kv_recover_idx_for_chunk)
attn_output_pre_full = torch.index_select(attn_output_pre_full, 0,
inverse_idx)
attn_lse_full = torch.index_select(attn_lse_full, 1, inverse_idx)
attn_output_pre_new = attn_output_pre_full[
self.pcp_rank * num_tokens:(self.pcp_rank + 1) *
num_tokens, :, :]
attn_lse_new = attn_lse_full[:, self.pcp_rank *
num_tokens:(self.pcp_rank + 1) *
num_tokens]
# update(output_origin, output_new)
assert attn_output_pre_new.shape == attn_output_pre.shape and attn_lse_new.shape == attn_lse.shape
seq_len = torch.tensor(attn_metadata.prefill.query_lens,
dtype=torch.int32)
mask_for_non_zero_chunk = attn_metadata.prefill.chunked_context.mask_for_non_zero_chunk
filtered_indices = filter_chunked_req_indices(
seq_len, mask_for_non_zero_chunk)
attn_output_pre_filtered = attn_output_pre[filtered_indices, :, :]
attn_lse_filtered = attn_lse[:, filtered_indices]
attn_output_pre_new = attn_output_pre_new[filtered_indices, :, :]
attn_lse_new = attn_lse_new[:, filtered_indices]
# normalize prefix LSE to [bs, heads, 1] for stable updates
attn_lse_filtered = attn_lse_filtered.permute(1, 0).unsqueeze(-1)
attn_lse_new = attn_lse_new.permute(1, 0).unsqueeze(-1)
token_mask = torch.ones([attn_lse_new.size(0)],
dtype=torch.uint8,
device=attn_lse_new.device)
attn_output_pre_filtered, attn_lse_filtered = self._update_out_and_lse(
attn_output_pre_filtered, attn_lse_filtered,
attn_output_pre_new, attn_lse_new, token_mask)
# convert lse back to [heads, bs]
attn_lse_filtered = attn_lse_filtered.squeeze(-1).permute(
1, 0).contiguous()
attn_output_pre[
filtered_indices, :, :] = attn_output_pre_filtered.to(
attn_output_pre.dtype)
attn_lse[:,
filtered_indices] = attn_lse_filtered.to(attn_lse.dtype)
attn_output_pre = attn_output_pre.to(q_nope.dtype)
output = attn_output_pre.reshape(
[num_tokens, self.num_heads * self.v_head_dim])
else:
output = output.reshape(
[num_tokens, self.num_heads * self.v_head_dim])
output = output.reshape([num_tokens, self.num_heads * self.v_head_dim])
return output
@@ -2164,32 +1909,75 @@ class AscendMLAImpl(MLAAttentionImpl):
return attn_out_lse_list
# TODO use update op to replace this
def _update_out_and_lse(
def _reorg_kvcache(
self,
out: torch.Tensor,
lse: torch.Tensor,
block_out: torch.Tensor,
block_lse: torch.Tensor,
mask: torch.Tensor = None,
):
if out is None:
out = block_out.to(torch.float32)
lse = block_lse
else:
if mask is None:
mask = torch.ones([block_out.size(0)],
dtype=torch.uint8,
device=block_out.device)
out_mask = mask[:, None, None].expand_as(block_out)
lse_mask = mask[:, None, None].expand_as(block_lse)
block_out = block_out.to(torch.float32)
out_without_update = out.clone()
lse_without_update = lse.clone()
out = out - F.sigmoid(block_lse - lse) * (out - block_out)
lse = lse - F.logsigmoid(lse - block_lse)
# mask
out = torch.where(out_mask, out, out_without_update)
lse = torch.where(lse_mask, lse, lse_without_update)
return out, lse
allgatered_kv_c_normed: torch.Tensor,
allgatered_k_pe: torch.Tensor,
padded_local_chunk_seq_lens_lst: list[int],
local_context_lens_allranks: list[list[int]],
sum_seq_len: int,
max_seq_len: int,
chunk_size: int,
chunk_idx: int,
toks: int,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
reorg and unpad kvcache after cp local gather to tp layout for attn kernel.
e.g.
kv_c_normed in rank0 = [T0_0, T0_1, T0_2, T0_3, T1_0, T1_1, ...]
kv_c_normed in rank1 = [T0_4, T0_5, pad, pad, T1_2, pad, ...]
allgatered_kv_c_normed = [T0_0, T0_1, T0_2, T0_3, T1_0, T1_1, ...,
T0_4, T0_5, pad, pad, T1_2, pad, ...]
-> reorganized_kv_c_normed = [T0_0, T0_1, T0_2, T0_3, T0_4, T0_5,
T1_0, T1_1, T1_2, ...]
Args:
padded_local_chunk_seq_lens_lst: local chunk context lengths
under current CP rank.
local_context_lens_allranks: local context lengths on each CP rank.
sum_seq_len: the sum of cp_chunk_seq_lens_lst.
max_seq_len: the max value of cp_chunk_seq_lens_lst.
chunk_size: the local padded max context chunk from
chunked_context_metadata building.
chunk_idx: chunk idx of chunked_prefill.
toks: the number of tokens for local gather cache.
"""
kv_c_segments = []
k_pe_segments = []
src_token_idx = 0
max_seq_len_check = 0
for padded_local_chunk_seq_len, local_context_lens in zip(
padded_local_chunk_seq_lens_lst, local_context_lens_allranks):
cur_seq_len = 0
for rank, local_context_len in enumerate(local_context_lens):
# Note(qcs): We split the context into multiple chunks,
# depending on the size of the workspace.
# local_context in dcp0: |-----------------|
# local_context in dcp1: |--------------|
# n*padded_local_chunk: |-----|-----|-----|
# local_chunk_len in dcp1: |-----|-----|--|
# so we need update the last chunk length in dcp1.
local_chunk_len = min(
max(0, local_context_len - chunk_idx * chunk_size),
padded_local_chunk_seq_len,
)
if local_chunk_len != 0:
kv_c_segment = allgatered_kv_c_normed[rank * toks +
src_token_idx:rank *
toks +
src_token_idx +
local_chunk_len]
k_pe_segment = allgatered_k_pe[rank * toks +
src_token_idx:rank * toks +
src_token_idx +
local_chunk_len]
kv_c_segments.append(kv_c_segment)
k_pe_segments.append(k_pe_segment)
cur_seq_len += local_chunk_len
max_seq_len_check = max(max_seq_len_check, cur_seq_len)
src_token_idx += padded_local_chunk_seq_len
reorganized_kv_c_normed = torch.cat(kv_c_segments, dim=0)
reorganized_k_pe = torch.cat(k_pe_segments, dim=0)
assert reorganized_kv_c_normed.shape[0] == sum_seq_len
assert reorganized_k_pe.shape[0] == sum_seq_len
assert max_seq_len_check == max_seq_len
return reorganized_kv_c_normed, reorganized_k_pe

View File

@@ -20,13 +20,6 @@ class AscendPrefillContextParallelMetadata:
num_computed_tokens_of_pcp_dcp: Optional[list[list[list[int]]]] = None
local_chunked_kv_lens: Optional[list[Optional[list[Optional[list[Optional[
list[int]]]]]]]] = None
mask_for_non_zero_chunk: Optional[List[bool]] = None
max_chunk_num: int = 0
q_head_idx_tensor: torch.Tensor = None
q_tail_idx_tensor: torch.Tensor = None
@@ -115,23 +108,6 @@ class AscendCommonAttentionMetadata:
AscendPrefillContextParallelMetadata] = None
def extract_req_dcp_by_chunk_pcp(lst,
chunk_idx,
dcp_size,
pcp_rank,
fill_value=0):
num_reqs = len(lst)
results: List[List[int]] = []
for i in range(num_reqs):
if len(lst[i]) == 0 or chunk_idx >= len(lst[i]):
# empty req or this req has no corresponding chunk, fill 0
results.append([fill_value] * dcp_size)
continue
dcp_values = lst[i][chunk_idx][pcp_rank]
results.append(dcp_values)
return results
def filter_chunked_req_indices(
seq_len: torch.Tensor,
mask_for_non_zero_chunk: Optional[List[bool]],

View File

@@ -29,7 +29,7 @@ from copy import deepcopy
from dataclasses import dataclass
from multiprocessing import Manager
from typing import (TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional,
Tuple, Union, cast)
Union, cast)
import numpy as np
import numpy.typing as npt
@@ -763,7 +763,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
backward_kwargs["mm_features"] = new_req_data.mm_features
# Create request state - PCP/DCP tracking will be computed below
req_state = CachedRequestState(
self.requests[req_id] = CachedRequestState(
req_id=req_id,
prompt_token_ids=new_req_data.prompt_token_ids,
prompt_embeds=new_req_data.prompt_embeds,
@@ -774,42 +774,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
num_computed_tokens=new_req_data.num_computed_tokens,
output_token_ids=[],
lora_request=new_req_data.lora_request,
local_chunked_kv_lens=None,
**backward_kwargs,
)
# Compute PCP/DCP tracking fields for chunked prefill
self.input_batch.local_chunked_kv_lens = [None] * self.max_num_reqs
if self.pcp_size * self.dcp_size > 1:
num_computed_tokens = new_req_data.num_computed_tokens
if num_computed_tokens > 0:
# Initialize with starting rank 0
temp_start_rank_dict = {req_id: (0, 0)}
# Compute token distribution for initial tokens
current_distribution = self.get_split_computed_tokens(
np.array([num_computed_tokens]),
request_ids=[req_id],
request_start_rank_dict=temp_start_rank_dict,
cp_kv_cache_interleave_size=self.parallel_config.
cp_kv_cache_interleave_size,
)[0]
# Update next_pcp_dcp_start_rank
req_state.next_pcp_dcp_start_rank = temp_start_rank_dict[
req_id][0]
req_state.token_blank_in_last_blk = temp_start_rank_dict[
req_id][1]
req_state.local_chunked_kv_lens = [
copy.deepcopy(current_distribution)
]
else:
# No computed tokens yet
req_state.local_chunked_kv_lens = []
self.requests[req_id] = req_state
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
if self.uses_mrope:
self._init_mrope_positions(self.requests[req_id])
@@ -826,44 +793,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
resumed_from_preemption = req_data.resumed_from_preemption[i]
# Update the cached states.
prev_num_computed_tokens = req_state.num_computed_tokens
req_state.num_computed_tokens = num_computed_tokens
# Compute PCP/DCP tracking fields for chunked prefill
if self.pcp_size * self.dcp_size > 1:
# If this is the first chunk, initialize tracking fields
if req_state.local_chunked_kv_lens is None:
req_state.local_chunked_kv_lens = []
# Compute tokens added in this chunk (not cumulative)
chunk_tokens = num_computed_tokens - prev_num_computed_tokens
if chunk_tokens > 0:
# Create a temporary dict with this request's starting rank
temp_start_rank_dict = {
req_id: (req_state.next_pcp_dcp_start_rank,
req_state.token_blank_in_last_blk)
}
# Compute distribution for this chunk only
chunk_distribution = self.get_split_computed_tokens(
np.array([chunk_tokens]),
request_ids=[req_id],
request_start_rank_dict=temp_start_rank_dict,
cp_kv_cache_interleave_size=self.parallel_config.
cp_kv_cache_interleave_size,
)[0]
# Update next_pcp_dcp_start_rank for this request
req_state.next_pcp_dcp_start_rank = temp_start_rank_dict[
req_id][0]
req_state.token_blank_in_last_blk = temp_start_rank_dict[
req_id][1]
# Append this chunk's distribution to accumulation list
req_state.local_chunked_kv_lens.append(
copy.deepcopy(chunk_distribution))
if not is_last_rank:
# When using PP, the scheduler sends the sampled tokens back,
# because there's no direct communication between the first-
@@ -908,10 +839,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.input_batch.block_table.append_row(
new_block_ids, req_index)
# Update PCP/DCP tracking fields in input_batch
self.input_batch.local_chunked_kv_lens[
req_index] = req_state.local_chunked_kv_lens
# For the last rank, we don't need to update the token_ids_cpu
# because the sampled tokens are already cached.
if not is_last_rank:
@@ -1477,7 +1404,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
scheduler_output,
decode_threshold=self.reorder_batch_threshold)
def generate_kv_idx(self, tokens, scheduler_output):
def generate_kv_idx(self, scheduler_output):
if not self.pcp_size > 1:
return
self.cp_kv_recover_idx_for_chunk = [[] for _ in range(self.pcp_size)]
@@ -1548,12 +1475,14 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.input_batch.num_computed_tokens_cpu[req_indices],
arange,
)
self.generate_kv_idx(tokens, scheduler_output)
self.input_batch.block_table.compute_slot_mapping(
req_indices, positions_np)
self.input_batch.block_table.commit_slot_mapping(
total_num_scheduled_tokens)
if self.pcp_size > 1:
if not self.vllm_config.model_config.use_mla:
self.generate_kv_idx(scheduler_output)
tokens, position_pcp, pcp_unpad_mask = self._update_tokens_for_pcp(
tokens)
num_scheduled_tokens = np.array(tokens, dtype=np.int32)
@@ -1905,18 +1834,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
num_reqs, scheduler_output.total_num_scheduled_tokens,
scheduler_output.num_scheduled_tokens)
# prepare pcp meta data
# For chunked prefill, use num_scheduled_tokens instead of cumulative seq_lens
# to correctly calculate chunk_len in _generate_pcp_metadata
if self.vllm_config.scheduler_config.chunked_prefill_enabled and self.pcp_size > 1:
# In chunked prefill, seq_lens_for_chunk should be the current chunk size
seq_lens_for_chunk = torch.from_numpy(
num_scheduled_tokens[:num_reqs])
else:
# Normal mode: use cumulative sequence lengths
seq_lens_for_chunk = seq_lens_cpu
long_seq_metadata = self._generate_pcp_metadata(
total_num_scheduled_tokens, seq_lens_for_chunk, seq_lens_cpu)
total_num_scheduled_tokens)
# Prepare the attention metadata for each KV cache group and make layers
# in the same group share the same metadata.
for kv_cache_group_id, kv_cache_group_spec in enumerate(
@@ -2842,6 +2761,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.query_start_loc[1:num_reqs + 1] = query_start_loc_tensor
self.query_start_loc_cpu[1:num_reqs +
1] = torch.Tensor(cu_num_tokens)
self.query_lens = torch.from_numpy(num_scheduled_tokens)
num_computed_tokens_cpu = (
self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs])
@@ -2855,8 +2775,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.cp_kv_recover_idx = torch.zeros(self.max_num_tokens,
dtype=torch.int32,
device=self.device)
long_seq_metadata = self._generate_pcp_metadata(
num_tokens, self.seq_lens_cpu, self.seq_lens_cpu)
long_seq_metadata = self._generate_pcp_metadata(num_tokens)
if long_seq_metadata is not None:
pcp_world_size = get_pcp_group(
).world_size if prefill_context_parallel_enable() else 1
@@ -4411,7 +4330,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
all_positions_tensor.float().argsort().long(), non_blocking=True)
return pcp_tokens, positions, unpad_mask
def _get_pcp_local_seq_lens(
def _get_cp_local_seq_lens(
self,
seq_lens: torch.Tensor,
pcp_world_size: int = 1,
@@ -4439,139 +4358,20 @@ class NPUModelRunner(LoRAModelRunnerMixin):
[-1, pcp_world_size, dcp_world_size])
return dcp_local_seq_lens
def get_split_computed_tokens(
self,
num_computed_tokens: np.ndarray,
request_ids: Optional[List[str]] = None,
request_start_rank_dict: Dict[str, tuple[
int, int]] = {}, # tuple: start_rank, tokens_blank_in_this_block
cp_kv_cache_interleave_size: int = 1
) -> list[Optional[list[Optional[list[int]]]]]:
"""Splits computed token counts across dcp and sp dimensions for distributed allocation.
Args:
num_computed_tokens: Number of tokens for each request (current chunk, not cumulative)
request_ids: Request IDs to track state
request_start_rank_dict: Dict mapping req_id to the starting rank for this chunk.
Will be updated with next starting rank after distribution.
Returns:
List of [pcp_size][dcp_size] distribution for each request
"""
self.pcp_world_size = get_pcp_group(
).world_size if prefill_context_parallel_enable() else 1
self.dcp_world_size = get_dcp_group().world_size
num_requests = len(num_computed_tokens)
assert request_start_rank_dict is not None and request_ids is not None and len(
request_ids) == num_requests
local_chunked_kv_lens = [[[0] * self.dcp_world_size
for _ in range(self.pcp_world_size)]
for _ in range(num_requests)]
total_ranks = self.pcp_world_size * self.dcp_world_size
for req_idx, (req_id, total_tokens) in enumerate(
zip(request_ids, num_computed_tokens)):
if total_tokens <= 0:
continue
# Get starting rank for this chunk
start_rank = 0
tokens_blank = 0
if request_start_rank_dict is not None:
start_rank, tokens_blank = request_start_rank_dict.get(
req_id, (0, 0))
if tokens_blank > 0: # need to continue writing in the last block of previous chunk
consumed_tokens = min(tokens_blank, total_tokens)
total_tokens -= consumed_tokens
tokens_blank -= consumed_tokens
pcp_idx = start_rank // self.dcp_world_size
dcp_idx = start_rank % self.dcp_world_size
local_chunked_kv_lens[req_idx][pcp_idx][
dcp_idx] += consumed_tokens
if tokens_blank == 0:
start_rank = (start_rank + 1) % total_ranks
if total_tokens == 0:
request_start_rank_dict[req_id] = (start_rank,
tokens_blank)
continue
virtual_size = total_ranks * cp_kv_cache_interleave_size
base = int(total_tokens) // virtual_size
# Distribute base tokens to all ranks
for rank_idx in range(total_ranks):
pcp_idx = rank_idx // self.dcp_world_size
dcp_idx = rank_idx % self.dcp_world_size
local_chunked_kv_lens[req_idx][pcp_idx][
dcp_idx] += base * cp_kv_cache_interleave_size
remainder = int(total_tokens) % virtual_size
if remainder == 0:
request_start_rank_dict[req_id] = (start_rank, tokens_blank)
continue
remain_blocks = cdiv(remainder, cp_kv_cache_interleave_size)
assert remain_blocks > 0
# Distribute remainder tokens starting from start_rank
for i in range(remain_blocks):
rank = (start_rank + i) % total_ranks
pcp_idx = rank // self.dcp_world_size
dcp_idx = rank % self.dcp_world_size
if i < remain_blocks - 1 or remainder % cp_kv_cache_interleave_size == 0: # not last block or divisible
local_chunked_kv_lens[req_idx][pcp_idx][
dcp_idx] += 1 * cp_kv_cache_interleave_size
tokens_blank = 0
else: # if last block and undivisible
local_chunked_kv_lens[req_idx][pcp_idx][
dcp_idx] += remainder % cp_kv_cache_interleave_size
tokens_blank = cp_kv_cache_interleave_size - (
remainder % cp_kv_cache_interleave_size)
start_rank = (start_rank + remain_blocks - 1) % total_ranks
if tokens_blank == 0:
start_rank = (start_rank + 1) % total_ranks
# Update next starting rank for this request
request_start_rank_dict[req_id] = (start_rank, tokens_blank)
return cast(List[Optional[List[Optional[List[int]]]]],
local_chunked_kv_lens)
def _get_chunked_req_mask_and_max_chunk(
self,
local_chunked_kv_lens: Optional[list[Optional[list[Optional[list[
Optional[list[int]]]]]]]] = None
) -> Tuple[List[bool], int]:
"""
given 4-d list [req][chunk][pcp][dcp], return:
1. if each req has any chunk (list[bool])
2. max chunk num along all reqs (int)
"""
assert local_chunked_kv_lens is not None
if len(local_chunked_kv_lens) == 0:
return ([], 0)
mask_for_non_zero_chunk = [
len(req) > 0 for req in local_chunked_kv_lens if req is not None
]
max_chunk_num = max(
(len(req) for req in local_chunked_kv_lens if req is not None),
default=0)
return mask_for_non_zero_chunk, max_chunk_num
def _generate_pcp_metadata(self, total_num_scheduled_tokens, seq_lens,
seq_lens_origin):
def _generate_pcp_metadata(self, total_num_scheduled_tokens):
# In dummy run num_reqs == 0, update it from seq_lens
num_reqs = self.input_batch.num_reqs or seq_lens.size(0)
num_reqs = self.input_batch.num_reqs or self.query_lens.size(0)
num_decodes = sum(self.input_batch.num_computed_tokens_cpu[:num_reqs]
>= self.input_batch.num_prompt_tokens[:num_reqs])
num_actual_tokens_pcp_padded = total_num_scheduled_tokens * self.pcp_size
self.num_actual_tokens_pcp_padded = num_actual_tokens_pcp_padded
local_chunked_kv_lens = self.input_batch.local_chunked_kv_lens[
num_decodes:num_reqs]
mask_for_non_zero_chunk, max_chunk_num = self._get_chunked_req_mask_and_max_chunk(
local_chunked_kv_lens)
long_seq_metadata = None
if self.pcp_size * self.dcp_size > 1:
decode_context_lens = self.input_batch.num_tokens[:num_decodes]
prefill_context_lens = self.input_batch.num_computed_tokens_cpu[
num_decodes:num_reqs]
context_lens = np.concatenate(
[decode_context_lens, prefill_context_lens])
num_computed_tokens_of_pcp_dcp = torch.zeros(
[
num_reqs * self.decode_threshold, self.pcp_size,
@@ -4584,8 +4384,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
for decode_idx in range(self.decode_threshold):
num_computed_tokens_of_pcp_dcp[
self.decode_threshold - 1 - decode_idx::self.decode_threshold] = \
self._get_pcp_local_seq_lens(
seq_lens_origin - decode_idx,
self._get_cp_local_seq_lens(
torch.tensor(context_lens),
self.pcp_size,
self.dcp_size,
self.parallel_config.cp_kv_cache_interleave_size,
@@ -4593,10 +4393,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
long_seq_metadata = AscendPrefillContextParallelMetadata(
num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded,
num_computed_tokens_of_pcp_dcp=num_computed_tokens_of_pcp_dcp.
numpy(),
local_chunked_kv_lens=local_chunked_kv_lens,
mask_for_non_zero_chunk=mask_for_non_zero_chunk,
max_chunk_num=max_chunk_num)
numpy())
if self.pcp_size > 1:
q_head_idx, q_tail_idx = [], []
kv_with_q_head_nomask_idx, kv_with_q_head_mask_idx = [], []
@@ -4607,7 +4404,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
kv_req_offset = 0
q_head_chunk_id = self.pcp_rank
q_tail_chunk_id = self.pcp_size * 2 - 1 - self.pcp_rank
for i, seq_len in enumerate(seq_lens):
for i, seq_len in enumerate(self.query_lens):
if i < num_decodes:
continue
chunk_len = seq_len // 2

View File

@@ -73,12 +73,6 @@ class CachedRequestState:
lora_request: Optional[LoRARequest] = None
prompt_embeds: Optional[torch.Tensor] = None
# pcp/dcp param
local_chunked_kv_lens: Optional[list[Optional[list[Optional[
list[int]]]]]] = None # Records computed tokens for each chunk
next_pcp_dcp_start_rank: int = 0 # Tracks next starting rank for round-robin distribution
token_blank_in_last_blk: int = 0 # if the last block is not full, how many future tokens can be stored
def __post_init__(self):
self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
self.prompt_token_ids, self.prompt_embeds)
@@ -319,10 +313,6 @@ class InputBatch:
self.prev_sampled_token_ids_invalid_indices: Optional[set[int]] = None
self.prev_req_id_to_index: Optional[dict[str, int]] = None
# pcp/dcp parameters
self.local_chunked_kv_lens: list[Optional[list[Optional[list[Optional[
list[int]]]]]]] = [None] * max_num_reqs
@property
def req_ids(self) -> list[str]:
# None elements should only be present transiently
@@ -395,9 +385,6 @@ class InputBatch:
self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
self.block_table.add_row(request.block_ids, req_index)
# Add PCP/DCP tracking fields
self.local_chunked_kv_lens[req_index] = request.local_chunked_kv_lens
if sampling_params := request.sampling_params:
if (self.is_spec_decode
and is_spec_decode_unsupported(sampling_params)):
@@ -693,8 +680,6 @@ class InputBatch:
last_req_index]
self.num_computed_tokens_cpu[
empty_index] = self.num_computed_tokens_cpu[last_req_index]
self.local_chunked_kv_lens[
empty_index] = self.local_chunked_kv_lens[last_req_index]
self.block_table.move_row(last_req_index, empty_index)
self.temperature_cpu[empty_index] = self.temperature_cpu[
last_req_index]