[long_seq_Feat] support chunk prefill (#4158)
### What this PR does / why we need it?
1、qwen GQA attention_v1 optim
2、DeepSeek MLA refactor, all gather q -> all gather kv
3、modelrunner refactor for chunk prefill, we remove some code not use
- vLLM version: v0.11.0
- vLLM main:
2918c1b49c
---------
Signed-off-by: LookAround <lixushi@huawei.com>
Signed-off-by: Delphine-Nic <tanwenqin@huawei.com>
Co-authored-by: Delphine-Nic <tanwenqin@huawei.com>
This commit is contained in:
@@ -44,7 +44,6 @@ from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
||||
extract_req_dcp_by_chunk_pcp,
|
||||
filter_chunked_req_indices,
|
||||
split_decodes_and_prefills)
|
||||
from vllm_ascend.compilation.acl_graph import (get_graph_params,
|
||||
@@ -169,10 +168,10 @@ class AscendMetadataForPrefill:
|
||||
@dataclass
|
||||
class ChunkedContextMetadata:
|
||||
actual_chunk_seq_lengths: list[int]
|
||||
mask_for_non_zero_chunk: Optional[list[bool]] = None
|
||||
max_chunk_num: int = 0
|
||||
local_chunked_kv_lens: Optional[list[Optional[list[Optional[list[
|
||||
Optional[list[int]]]]]]]] = None
|
||||
actual_seq_lengths_kv: list[int]
|
||||
starts: torch.Tensor
|
||||
chunked_req_mask: Optional[list[bool]] = None
|
||||
local_context_lens_allranks: Optional[list[list[int]]] = None
|
||||
cp_kv_recover_idx_for_chunk: Optional[list[int]] = None
|
||||
kv_inverse_idx_for_chunk: Optional[list[int]] = None
|
||||
|
||||
@@ -286,25 +285,7 @@ class AscendAttentionMetadataBuilder:
|
||||
AscendAttentionMetadataBuilder.reorder_batch_threshold = self.decode_threshold
|
||||
|
||||
scheduler_config = vllm_config.scheduler_config
|
||||
self.block_size = vllm_config.cache_config.block_size
|
||||
self.max_blocks = (vllm_config.model_config.max_model_len +
|
||||
self.block_size - 1) // self.block_size
|
||||
self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
|
||||
if self.chunked_prefill_enabled:
|
||||
self.chunked_prefill_workspace_size = min(
|
||||
# Max sure there is enough for 8 full length request or at least
|
||||
# 4 pages of cache per request
|
||||
max(8 * self.model_config.max_model_len,
|
||||
4 * scheduler_config.max_num_seqs * self.block_size),
|
||||
128 * 1024)
|
||||
assert self.chunked_prefill_workspace_size >= \
|
||||
scheduler_config.max_num_seqs * self.block_size
|
||||
self.chunked_prefill_workspace = torch.empty(
|
||||
(self.chunked_prefill_workspace_size,
|
||||
self.model_config.get_head_size()),
|
||||
dtype=self.model_config.dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
def reorder_batch(self, input_batch,
|
||||
scheduler_output: "SchedulerOutput") -> bool:
|
||||
@@ -385,6 +366,8 @@ class AscendAttentionMetadataBuilder:
|
||||
prefill_metadata = None
|
||||
decode_metadata = None
|
||||
if common_long_seq_metadata is not None:
|
||||
num_computed_tokens_of_pcp_dcp = common_long_seq_metadata.num_computed_tokens_of_pcp_dcp
|
||||
assert num_computed_tokens_of_pcp_dcp is not None
|
||||
chunked_context_metadata = None
|
||||
if num_prefills > 0:
|
||||
query_lens = query_lens[num_decode_tokens:]
|
||||
@@ -394,18 +377,39 @@ class AscendAttentionMetadataBuilder:
|
||||
pcp_size = get_prefill_context_model_parallel_world_size(
|
||||
) if prefill_context_parallel_enable() else 1
|
||||
if self.chunked_prefill_enabled and max_context_len_cpu > 0:
|
||||
local_context_lens_allranks = torch.tensor(
|
||||
num_computed_tokens_of_pcp_dcp
|
||||
)[num_decodes:num_reqs].to(
|
||||
self.device).to(dtype=torch.int32)
|
||||
local_chunked_kv_lens_rank = local_context_lens_allranks[:,
|
||||
self
|
||||
.
|
||||
pcp_rank,
|
||||
self
|
||||
.
|
||||
dcp_rank]
|
||||
actual_seq_lengths_kv = torch.cumsum(
|
||||
local_chunked_kv_lens_rank, dim=0).tolist()
|
||||
chunked_req_mask = self._get_chunked_req_mask(
|
||||
local_context_lens_allranks)
|
||||
local_chunk_starts = torch.zeros(
|
||||
(len(local_context_lens_allranks)),
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
cp_kv_recover_idx_for_chunk = common_long_seq_metadata.cp_kv_recover_idx_for_chunk
|
||||
kv_inverse_idx_for_chunk = torch.argsort(
|
||||
cp_kv_recover_idx_for_chunk.to(torch.float32)
|
||||
) if cp_kv_recover_idx_for_chunk is not None else None
|
||||
|
||||
chunked_context_metadata = \
|
||||
AscendMetadataForPrefill.ChunkedContextMetadata(
|
||||
actual_chunk_seq_lengths=torch.cumsum(query_lens * pcp_size, dim=0),
|
||||
mask_for_non_zero_chunk=common_long_seq_metadata.mask_for_non_zero_chunk,
|
||||
local_chunked_kv_lens=common_long_seq_metadata.local_chunked_kv_lens,
|
||||
actual_seq_lengths_kv=actual_seq_lengths_kv,
|
||||
chunked_req_mask=chunked_req_mask,
|
||||
starts=local_chunk_starts,
|
||||
local_context_lens_allranks=local_context_lens_allranks,
|
||||
cp_kv_recover_idx_for_chunk=cp_kv_recover_idx_for_chunk,
|
||||
kv_inverse_idx_for_chunk=kv_inverse_idx_for_chunk,
|
||||
max_chunk_num=common_long_seq_metadata.max_chunk_num
|
||||
kv_inverse_idx_for_chunk=kv_inverse_idx_for_chunk
|
||||
)
|
||||
attn_mask_seqlens = common_long_seq_metadata.attn_mask_seqlens
|
||||
head_attn_nomask_seqlens = common_long_seq_metadata.head_attn_nomask_seqlens
|
||||
@@ -445,8 +449,6 @@ class AscendAttentionMetadataBuilder:
|
||||
actual_seq_lengths_q=torch.cumsum(query_lens, dim=0))
|
||||
|
||||
if num_decodes > 0:
|
||||
num_computed_tokens_of_pcp_dcp = common_long_seq_metadata.num_computed_tokens_of_pcp_dcp
|
||||
assert num_computed_tokens_of_pcp_dcp is not None
|
||||
num_computed_tokens_array = np.array(
|
||||
num_computed_tokens_of_pcp_dcp)
|
||||
num_computed_tokens_array = num_computed_tokens_array[:
|
||||
@@ -483,6 +485,19 @@ class AscendAttentionMetadataBuilder:
|
||||
decode_meta=decode_metadata)
|
||||
return attn_metadata
|
||||
|
||||
def _get_chunked_req_mask(self, local_context_lens_allranks) -> List[bool]:
|
||||
"""
|
||||
given 4-d list [req][pcp][dcp], return:
|
||||
1. if each req has any chunk (list[bool])
|
||||
"""
|
||||
assert local_context_lens_allranks is not None
|
||||
if len(local_context_lens_allranks) == 0:
|
||||
return []
|
||||
chunked_req_mask = [(req.sum() > 0).item()
|
||||
for req in local_context_lens_allranks
|
||||
if req is not None]
|
||||
return chunked_req_mask
|
||||
|
||||
def build_for_graph_capture(
|
||||
self,
|
||||
common_attn_metadata: AscendCommonAttentionMetadata,
|
||||
@@ -1205,11 +1220,11 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
self.pcp_rank * num_tokens:(self.pcp_rank + 1) * num_tokens, :, :]
|
||||
attn_lse_full_chunk = attn_lse_full_chunk[
|
||||
self.pcp_rank * num_tokens:(self.pcp_rank + 1) * num_tokens, :, :]
|
||||
|
||||
assert attn_output_full_chunk.shape == current_attn_output_prefill.shape and attn_lse_full_chunk.shape == current_attn_lse_prefill.shape
|
||||
seq_len = attn_metadata.query_lens.detach().clone()
|
||||
filtered_indices = filter_chunked_req_indices(
|
||||
seq_len,
|
||||
attn_metadata.prefill.chunked_context.mask_for_non_zero_chunk)
|
||||
seq_len, attn_metadata.prefill.chunked_context.chunked_req_mask)
|
||||
|
||||
attn_output_prefill_filtered = current_attn_output_prefill[
|
||||
filtered_indices, :, :]
|
||||
@@ -1221,18 +1236,23 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
attn_output_filtered = self._npu_attn_out_lse_update(
|
||||
attn_lse_prefill_filtered, attn_lse_full_chunk,
|
||||
attn_output_prefill_filtered, attn_output_full_chunk)
|
||||
|
||||
current_attn_output_prefill[
|
||||
filtered_indices, :, :] = attn_output_filtered.to(
|
||||
current_attn_output_prefill.dtype)
|
||||
|
||||
def _prefill_query_all_gather(self, attn_metadata, prefill_query):
|
||||
prefill_query_all = get_pcp_group().all_gather(prefill_query.contiguous(),
|
||||
0) \
|
||||
if self.pcp_size > 1 else prefill_query
|
||||
prefill_query_all = torch.index_select(prefill_query_all,
|
||||
if self.dcp_size > 1:
|
||||
prefill_query = get_dcp_group().all_gather(prefill_query, 1)
|
||||
|
||||
if self.pcp_size > 1:
|
||||
prefill_query = get_pcp_group().all_gather(prefill_query, 0)
|
||||
|
||||
prefill_query_all = torch.index_select(prefill_query,
|
||||
0,
|
||||
attn_metadata.prefill.chunked_context.cp_kv_recover_idx_for_chunk) \
|
||||
if self.pcp_size > 1 else prefill_query_all
|
||||
if self.pcp_size > 1 else prefill_query
|
||||
|
||||
return prefill_query_all
|
||||
|
||||
def _compute_prefill_context(self, query: torch.Tensor,
|
||||
@@ -1243,217 +1263,132 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
assert attn_metadata.prefill is not None
|
||||
assert attn_metadata.prefill.chunked_context is not None
|
||||
prefill_metadata = attn_metadata.prefill
|
||||
local_chunked_kv_lens = attn_metadata.prefill.chunked_context.local_chunked_kv_lens
|
||||
mask_for_non_zero_chunk = prefill_metadata.chunked_context.mask_for_non_zero_chunk
|
||||
max_chunk_num = prefill_metadata.chunked_context.max_chunk_num
|
||||
local_chunked_kv_lens = prefill_metadata.chunked_context.local_context_lens_allranks
|
||||
assert local_chunked_kv_lens is not None
|
||||
|
||||
assert local_chunked_kv_lens is not None and mask_for_non_zero_chunk is not None and max_chunk_num > 0
|
||||
local_chunked_kv_lens_rank = local_chunked_kv_lens[:, self.pcp_rank,
|
||||
self.dcp_rank]
|
||||
|
||||
iters = max_chunk_num
|
||||
# Keep the causal mask; do not override to all-ones. [req_id][chunk_id][cp-rank][dcp_rank]
|
||||
context_starts_rank = None
|
||||
|
||||
prefix_output_list = []
|
||||
prefix_lse_list = []
|
||||
for i in range(iters):
|
||||
key, value, seq_lens_current_chunk_rank = self._load_kv_for_chunk(
|
||||
attn_metadata, kv_cache, context_starts_rank, i,
|
||||
local_chunked_kv_lens, prefill_metadata, query)
|
||||
|
||||
# 2. Attention computation
|
||||
if seq_lens_current_chunk_rank is None or torch.all(
|
||||
seq_lens_current_chunk_rank == 0).item():
|
||||
prefix_output = torch.full(
|
||||
(query.size(0), self.num_heads, self.head_size),
|
||||
fill_value=0,
|
||||
dtype=query.dtype,
|
||||
device=query.device)
|
||||
prefix_lse = torch.full((query.size(0), self.num_heads, 1),
|
||||
fill_value=0,
|
||||
dtype=torch.float32,
|
||||
device=query.device)
|
||||
else:
|
||||
|
||||
actual_seq_lengths_kv = torch.cumsum(
|
||||
seq_lens_current_chunk_rank, dim=0).tolist()
|
||||
prefix_output, prefix_lse = torch.ops.npu.npu_fused_infer_attention_score(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
num_heads=self.num_heads,
|
||||
num_key_value_heads=self.num_kv_heads,
|
||||
input_layout="TND", #
|
||||
atten_mask=None,
|
||||
scale=self.scale,
|
||||
sparse_mode=0,
|
||||
antiquant_mode=0,
|
||||
antiquant_scale=None,
|
||||
softmax_lse_flag=True,
|
||||
actual_seq_lengths_kv=actual_seq_lengths_kv,
|
||||
actual_seq_lengths=attn_metadata.prefill.chunked_context.
|
||||
actual_chunk_seq_lengths)
|
||||
prefix_output_list.append(prefix_output)
|
||||
prefix_lse_list.append(prefix_lse)
|
||||
|
||||
# 3. update attn-out & lse
|
||||
prefix_output, prefix_lse = self._update_attn_out_lse_in_chunks(
|
||||
prefix_output_list, prefix_lse_list)
|
||||
|
||||
self._update_attn_out_lse_in_pcp(attn_metadata, prefix_output,
|
||||
prefix_lse)
|
||||
|
||||
return prefix_output, prefix_lse
|
||||
|
||||
def _update_attn_out_lse_in_chunks(self, prefix_output_list,
|
||||
prefix_lse_list):
|
||||
# update output and lse
|
||||
if len(prefix_output_list) > 1:
|
||||
prefix_output, prefix_lse = self._update_out_and_lse(
|
||||
torch.stack(prefix_output_list, dim=0),
|
||||
torch.stack(prefix_lse_list, dim=0))
|
||||
key, value = self._load_kv_for_chunk(attn_metadata, kv_cache,
|
||||
local_chunked_kv_lens_rank, query)
|
||||
if self.dcp_size > 1:
|
||||
num_heads = self.num_heads * self.dcp_size
|
||||
else:
|
||||
prefix_output = prefix_output_list[0]
|
||||
prefix_lse = prefix_lse_list[0]
|
||||
num_heads = self.num_heads
|
||||
|
||||
prefix_chunk_output = torch.full(
|
||||
(query.size(0), num_heads, self.head_size),
|
||||
fill_value=0,
|
||||
dtype=query.dtype,
|
||||
device=query.device)
|
||||
prefix_chunk_lse = torch.full((query.size(0), num_heads, 1),
|
||||
fill_value=-torch.inf,
|
||||
dtype=torch.float32,
|
||||
device=query.device)
|
||||
|
||||
if not torch.all(local_chunked_kv_lens_rank == 0).item():
|
||||
prefix_chunk_output, prefix_chunk_lse = torch.ops.npu.npu_fused_infer_attention_score(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
num_heads=num_heads,
|
||||
num_key_value_heads=self.num_kv_heads,
|
||||
input_layout="TND",
|
||||
atten_mask=None,
|
||||
scale=self.scale,
|
||||
sparse_mode=0,
|
||||
antiquant_mode=0,
|
||||
antiquant_scale=None,
|
||||
softmax_lse_flag=True,
|
||||
actual_seq_lengths_kv=prefill_metadata.chunked_context.
|
||||
actual_seq_lengths_kv,
|
||||
actual_seq_lengths=attn_metadata.prefill.chunked_context.
|
||||
actual_chunk_seq_lengths)
|
||||
|
||||
prefix_output, prefix_lse = self._update_chunk_attn_out_lse(
|
||||
prefix_chunk_output, prefix_chunk_lse)
|
||||
|
||||
return prefix_output, prefix_lse
|
||||
|
||||
def _update_attn_out_lse_in_pcp(self, attn_metadata, prefix_output,
|
||||
prefix_lse):
|
||||
def _update_chunk_attn_out_lse(self, prefix_chunk_output,
|
||||
prefix_chunk_lse):
|
||||
# CP dimension all_gather and fusion
|
||||
if self.pcp_size > 1:
|
||||
# filter non-zero chunk part of prefix_output
|
||||
current_seq_lens = attn_metadata.query_lens.detach().clone()
|
||||
current_seq_lens.mul_(self.pcp_size) # q_full
|
||||
current_seq_lens_cpu = current_seq_lens.cpu()
|
||||
filtered_indices = filter_chunked_req_indices(
|
||||
current_seq_lens_cpu,
|
||||
attn_metadata.prefill.chunked_context.mask_for_non_zero_chunk)
|
||||
prefix_output_filtered = prefix_output[filtered_indices, :, :]
|
||||
prefix_lse_filtered = prefix_lse[filtered_indices, :, :]
|
||||
chunk_attn_out_lse = torch.cat([prefix_chunk_output, prefix_chunk_lse],
|
||||
dim=-1)
|
||||
|
||||
out_lse_local = torch.cat(
|
||||
[prefix_output_filtered, prefix_lse_filtered], dim=-1)
|
||||
if self.dcp_size > 1:
|
||||
chunk_attn_out_lse = chunk_attn_out_lse.permute([1, 2,
|
||||
0]).contiguous()
|
||||
attn_out_lse_all2all = torch.empty_like(chunk_attn_out_lse)
|
||||
dist.all_to_all_single(attn_out_lse_all2all,
|
||||
chunk_attn_out_lse,
|
||||
group=self.dcp_group)
|
||||
attn_out_lse_all2all = attn_out_lse_all2all.permute([2, 0, 1])
|
||||
if self.pcp_size > 1:
|
||||
chunk_attn_out_lse = attn_out_lse_all2all.contiguous()
|
||||
|
||||
attn_out_lse_list = list(
|
||||
torch.chunk(attn_out_lse_all2all, self.dcp_size, dim=1))
|
||||
|
||||
if self.pcp_size > 1:
|
||||
attn_out_lse_list = [
|
||||
torch.empty_like(out_lse_local) for _ in range(self.pcp_size)
|
||||
torch.empty_like(chunk_attn_out_lse)
|
||||
for _ in range(self.pcp_size)
|
||||
]
|
||||
dist.all_gather(attn_out_lse_list,
|
||||
out_lse_local,
|
||||
chunk_attn_out_lse,
|
||||
group=self.pcp_group)
|
||||
|
||||
attn_out_lse_allgather = torch.stack(
|
||||
attn_out_lse_list,
|
||||
dim=0) # [pcp, batch_size, num_heads, head_size+1]
|
||||
attn_out_allgather, attn_lse_allgather = torch.split(
|
||||
attn_out_lse_allgather, [self.head_size, 1], dim=-1)
|
||||
prefix_output_filtered, prefix_lse_filtered = self._update_out_and_lse(
|
||||
attn_out_allgather, attn_lse_allgather)
|
||||
if self.dcp_size > 1 and self.pcp_size > 1:
|
||||
attn_out_lse_list_pcp_dcp = []
|
||||
for s in attn_out_lse_list:
|
||||
attn_out_lse_list_split = list(
|
||||
torch.chunk(s, self.dcp_size, dim=1))
|
||||
attn_out_lse_list_pcp_dcp += attn_out_lse_list_split
|
||||
attn_out_lse_list = attn_out_lse_list_pcp_dcp
|
||||
|
||||
prefix_output[filtered_indices, :, :] = prefix_output_filtered.to(
|
||||
prefix_output.dtype)
|
||||
prefix_lse[filtered_indices, :, :] = prefix_lse_filtered.to(
|
||||
prefix_lse.dtype)
|
||||
attn_out_lse_allgather = torch.stack(
|
||||
attn_out_lse_list,
|
||||
dim=0) # [pcp, batch_size, num_heads, head_size+1]
|
||||
attn_out_allgather, attn_lse_allgather = torch.split(
|
||||
attn_out_lse_allgather, [self.head_size, 1], dim=-1)
|
||||
|
||||
def _load_kv_for_chunk(self, attn_metadata, kv_cache, context_starts_rank,
|
||||
i, local_chunked_kv_lens, prefill_metadata, query):
|
||||
prefix_output, prefix_lse = self._update_out_and_lse(
|
||||
attn_out_allgather, attn_lse_allgather)
|
||||
|
||||
return prefix_output, prefix_lse
|
||||
|
||||
def _load_kv_for_chunk(self, attn_metadata, kv_cache,
|
||||
local_chunked_kv_lens_rank, query):
|
||||
cache_key = kv_cache[0]
|
||||
cache_value = kv_cache[1]
|
||||
num_heads = cache_key.size(2)
|
||||
head_size = kv_cache[0].size(-1)
|
||||
|
||||
# 1. Load current query's history key-value
|
||||
seq_lens_current_chunk = attn_metadata.query_lens.detach().clone()
|
||||
num_requests = len(seq_lens_current_chunk)
|
||||
# Before dealing with a new chunk, set to zero, and accumulate the start positions as chunk prefill step increases
|
||||
context_starts_rank = torch.zeros(
|
||||
num_requests, dtype=torch.int32, device=query.device
|
||||
) if context_starts_rank is None else context_starts_rank
|
||||
# Calculate tokens each rank should process per request
|
||||
seq_lens_current_chunk_rank = torch.zeros_like(seq_lens_current_chunk,
|
||||
dtype=torch.int32,
|
||||
device=query.device)
|
||||
total_toks = 0
|
||||
for req_idx in range(num_requests):
|
||||
if i >= len(local_chunked_kv_lens[req_idx]):
|
||||
continue
|
||||
n_computed_acc = local_chunked_kv_lens[req_idx][i]
|
||||
total_toks += n_computed_acc[self.pcp_rank][self.dcp_rank]
|
||||
seq_lens_current_chunk_rank[req_idx] = n_computed_acc[
|
||||
self.pcp_rank][self.dcp_rank]
|
||||
if total_toks > 0:
|
||||
key = torch.empty(total_toks,
|
||||
num_heads,
|
||||
head_size,
|
||||
dtype=query.dtype,
|
||||
device=query.device)
|
||||
value = torch.empty(total_toks,
|
||||
num_heads,
|
||||
head_size,
|
||||
dtype=query.dtype,
|
||||
device=query.device)
|
||||
total_toks = local_chunked_kv_lens_rank.sum()
|
||||
|
||||
key = torch.empty(total_toks,
|
||||
num_heads,
|
||||
head_size,
|
||||
dtype=query.dtype,
|
||||
device=query.device)
|
||||
value = torch.empty(total_toks,
|
||||
num_heads,
|
||||
head_size,
|
||||
dtype=query.dtype,
|
||||
device=query.device)
|
||||
if total_toks > 0:
|
||||
torch_npu.atb.npu_paged_cache_load(
|
||||
cache_key,
|
||||
cache_value,
|
||||
attn_metadata.prefill.block_tables,
|
||||
seq_lens_current_chunk_rank.to(query.device),
|
||||
seq_starts=
|
||||
context_starts_rank, # slot offsets of current chunk in current iteration
|
||||
local_chunked_kv_lens_rank,
|
||||
seq_starts=attn_metadata.prefill.chunked_context.
|
||||
starts, # slot offsets of current chunk in current iteration
|
||||
key=key,
|
||||
value=value,
|
||||
)
|
||||
else:
|
||||
# If current rank has no tokens to process, create empty tensors
|
||||
key = torch.empty(0,
|
||||
self.num_heads,
|
||||
self.head_size,
|
||||
dtype=query.dtype,
|
||||
device=query.device)
|
||||
value = torch.empty(0,
|
||||
self.num_heads,
|
||||
self.head_size,
|
||||
dtype=query.dtype,
|
||||
device=query.device)
|
||||
seq_lens_current_chunk_rank = torch.zeros(
|
||||
(len(seq_lens_current_chunk), ),
|
||||
dtype=torch.int32,
|
||||
device=query.device)
|
||||
for req_idx in range(num_requests):
|
||||
# Before dealing with a new chunk, set to zero, and accumulate the start positions as chunk prefill step increases
|
||||
if i >= len(local_chunked_kv_lens[req_idx]):
|
||||
continue
|
||||
context_starts_rank[req_idx] += local_chunked_kv_lens[req_idx][i][
|
||||
self.pcp_rank][self.dcp_rank]
|
||||
if self.dcp_size > 1:
|
||||
req_dcp_sizes = extract_req_dcp_by_chunk_pcp(
|
||||
local_chunked_kv_lens, i, self.dcp_size, self.pcp_rank)
|
||||
|
||||
assert len(req_dcp_sizes) == num_requests and all(
|
||||
len(dcp_arr) == self.dcp_size for dcp_arr in req_dcp_sizes)
|
||||
total_toks = np.sum(np.array(req_dcp_sizes))
|
||||
kv_local = torch.cat([key, value], dim=-1)
|
||||
head_dim = kv_local.size(-1)
|
||||
kv_full = torch.empty((total_toks, num_heads, head_dim),
|
||||
device=query.device,
|
||||
dtype=query.dtype)
|
||||
|
||||
kv_full_list = [None for _ in range(self.dcp_size)]
|
||||
dist.all_gather_object(kv_full_list,
|
||||
kv_local,
|
||||
group=self.dcp_group)
|
||||
kv_full_list = [
|
||||
kv for kv in kv_full_list if kv is not None and kv.numel() > 0
|
||||
]
|
||||
|
||||
if len(kv_full_list) > 0:
|
||||
kv_full = torch.cat(kv_full_list, dim=0)
|
||||
key, value = kv_full.split([head_size, head_size], dim=-1)
|
||||
if total_toks == 0:
|
||||
return key, value, None
|
||||
seq_lens_current_chunk_rank = torch.tensor(
|
||||
np.sum(np.array(req_dcp_sizes), axis=1),
|
||||
dtype=torch.int32,
|
||||
device=query.device) # [reqs]
|
||||
return key, value, seq_lens_current_chunk_rank
|
||||
return key, value
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user