Disaggregate prefill for kv cache register style (#950)
### What this PR does / why we need it?
This PR adopts `LLMDataDist` for kv cache registration and a `pull_blocks`-style disaggregated prefill implementation. The interface implementation
mainly follows the design of NIXL PR
https://github.com/vllm-project/vllm/pull/17751/files#diff-7eaad0b7dee0626bf29d10081b0f0c5e3ea15a4af97e7b182a4e0d35f8346953
.
This PR can be tested with the following steps:
- Generate the rank table for all machines.
- Execute `toy_proxy.py` to launch the disaggregated prefill proxy server,
specifying the prefill IP/port and the decode IP/port.
- Run the prefill server and decode server.
- Send requests to the disaggregated prefill proxy.
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.9.2
- vLLM main:
8d0a01a5f2
---------
Signed-off-by: ganyi <pleaplusone.gy@gmail.com>
Signed-off-by: machenglong <machenglong_yewu@cmss.chinamobile.com>
Signed-off-by: liziyu179 <3475441767@qq.com>
Signed-off-by: underfitc <hucong24@huawei.com>
Signed-off-by: zouyida2052 <zouyida@huawei.com>
Signed-off-by: liziyu <liziyu16@huawei.com>
Signed-off-by: underfituu <hzhucong@163.com>
Co-authored-by: machenglong <machenglong_yewu@cmss.chinamobile.com>
Co-authored-by: liziyu179 <3475441767@qq.com>
Co-authored-by: underfitc <hucong24@huawei.com>
Co-authored-by: zouyida2052 <zouyida@huawei.com>
Co-authored-by: liziyu <liziyu16@huawei.com>
Co-authored-by: underfituu <hzhucong@163.com>
This commit is contained in:
@@ -252,7 +252,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
kv_cache: Tuple[torch.Tensor],
|
||||
attn_metadata: AscendMetadata,
|
||||
output: Optional[torch.Tensor] = None,
|
||||
trace_flag: bool = True,
|
||||
@@ -262,8 +262,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
query: shape = [batch_size, seq_len, num_heads * head_size]
|
||||
key: shape = [batch_size, seq_len, num_kv_heads * head_size]
|
||||
value: shape = [batch_size, seq_len, num_kv_heads * head_size]
|
||||
kv_cache: shape = [2, num_blocks, block_size,
|
||||
num_kv_heads, head_size]
|
||||
kv_cache: shape = [key_cache, value_cache]
|
||||
key_cache = [num_blocks, block_size,
|
||||
num_kv_heads, head_size]
|
||||
value_cache = [num_blocks, block_size,
|
||||
@@ -273,8 +272,8 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
shape = [batch_size * seq_len, num_heads, head_size]
|
||||
"""
|
||||
num_tokens = query.shape[0]
|
||||
use_kv_cache_int8 = kv_cache.numel(
|
||||
) > 0 and kv_cache[0].dtype == torch.int8
|
||||
use_kv_cache_int8 = len(
|
||||
kv_cache) > 0 and kv_cache[0].dtype == torch.int8
|
||||
if output is None:
|
||||
output = torch.empty(num_tokens,
|
||||
self.num_heads,
|
||||
@@ -314,7 +313,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
# TODO: Remove this contiguous in the future.
|
||||
value = value.contiguous()
|
||||
|
||||
if kv_cache.numel() > 0:
|
||||
if len(kv_cache) > 1:
|
||||
if self.key_cache is None:
|
||||
self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
|
||||
slots = attn_metadata.slot_mapping
|
||||
|
||||
@@ -62,7 +62,7 @@ class AscendAttentionTorchairBackend(AttentionBackend):
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
) -> Tuple[int, ...]:
|
||||
return (num_blocks, block_size, num_kv_heads * head_size)
|
||||
return (2, num_blocks, block_size, num_kv_heads * head_size)
|
||||
|
||||
@staticmethod
|
||||
def get_bsh_kv_cache_shape(
|
||||
@@ -71,7 +71,7 @@ class AscendAttentionTorchairBackend(AttentionBackend):
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
) -> Tuple[int, ...]:
|
||||
return (num_blocks, block_size, num_kv_heads * head_size)
|
||||
return (2, num_blocks, block_size, num_kv_heads * head_size)
|
||||
|
||||
@staticmethod
|
||||
def swap_blocks(
|
||||
|
||||
@@ -14,6 +14,7 @@ from vllm.model_executor.layers.linear import (LinearBase,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.utils import cdiv, round_down
|
||||
|
||||
from vllm_ascend import envs
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||
from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
|
||||
@@ -648,12 +649,13 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
def _compute_prefill_context(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: Tuple[torch.Tensor],
|
||||
rope_dim: int,
|
||||
attn_metadata: AscendMLAMetadata,
|
||||
prefix_output: torch.Tensor,
|
||||
prefix_lse: torch.Tensor,
|
||||
):
|
||||
assert len(kv_c_and_k_pe_cache) > 1
|
||||
prefill_metadata = attn_metadata.prefill
|
||||
if prefill_metadata is None or prefill_metadata.chunked_context is None:
|
||||
return prefix_output, prefix_lse
|
||||
@@ -663,21 +665,22 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
q_nope = query[..., :self.qk_nope_head_dim]
|
||||
|
||||
seq_len1 = torch.tensor(prefill_metadata.query_lens, dtype=torch.int32)
|
||||
latent_kv_dim = kv_c_and_k_pe_cache.size(3) - rope_dim
|
||||
cache_kv_c = kv_c_and_k_pe_cache[:, :, :, :latent_kv_dim]
|
||||
cache_k_pe = kv_c_and_k_pe_cache[:, :, :, latent_kv_dim:]
|
||||
cache_kv_c = kv_c_and_k_pe_cache[0]
|
||||
cache_k_pe = kv_c_and_k_pe_cache[1]
|
||||
num_heads = cache_k_pe.size(2)
|
||||
latent_kv_dim = kv_c_and_k_pe_cache[0].size(-1)
|
||||
for i in range(iters):
|
||||
toks = prefill_metadata.chunked_context.seq_tot[i]
|
||||
|
||||
seq_len2 = prefill_metadata.chunked_context.chunk_seq_lens[i]
|
||||
seq_len = torch.stack([seq_len1, seq_len2])
|
||||
kv_c_normed = torch.empty(toks,
|
||||
kv_c_and_k_pe_cache.size(2),
|
||||
num_heads,
|
||||
latent_kv_dim,
|
||||
dtype=query.dtype,
|
||||
device=query.device)
|
||||
k_pe = torch.empty(toks,
|
||||
kv_c_and_k_pe_cache.size(2),
|
||||
num_heads,
|
||||
rope_dim,
|
||||
dtype=query.dtype,
|
||||
device=query.device)
|
||||
@@ -727,10 +730,11 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
query: torch.Tensor,
|
||||
kv_c_normed: torch.Tensor,
|
||||
k_pe: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: Tuple[torch.Tensor],
|
||||
attn_metadata: AscendMLAMetadata,
|
||||
) -> torch.Tensor:
|
||||
assert attn_metadata.prefill is not None
|
||||
assert len(kv_c_and_k_pe_cache) > 1
|
||||
|
||||
num_tokens = query.size(0)
|
||||
attn_output = torch.empty(num_tokens,
|
||||
@@ -923,19 +927,13 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
q_pe: torch.Tensor,
|
||||
k_nope: torch.Tensor,
|
||||
k_pe: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: Tuple[torch.Tensor],
|
||||
attn_metadata: AscendMLAMetadata,
|
||||
enable_multistream_mla: bool = False,
|
||||
) -> torch.Tensor:
|
||||
decode_meta = attn_metadata.decode
|
||||
assert decode_meta is not None
|
||||
|
||||
q = torch.cat([q_nope, q_pe], dim=-1)
|
||||
num_tokens = q.size(0)
|
||||
attn_output = torch.empty(
|
||||
[num_tokens, self.num_heads, self.kv_lora_rank],
|
||||
dtype=q.dtype,
|
||||
device=q.device)
|
||||
num_tokens = q_nope.size(0)
|
||||
if self.running_in_graph:
|
||||
# TorchAir's shape is [bs, num_heads_per_rank, q_seq_len, dim]
|
||||
if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
|
||||
@@ -994,16 +992,35 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
actual_seq_lengths_kv=decode_meta.seq_lens_list,
|
||||
)
|
||||
else:
|
||||
torch_npu._npu_paged_attention_mla(
|
||||
query=q,
|
||||
key_cache=kv_c_and_k_pe_cache,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
num_heads=self.num_heads,
|
||||
scale_value=self.scale,
|
||||
block_table=attn_metadata.decode.block_table, # type:ignore
|
||||
context_lens=attn_metadata.decode.seq_lens, # type:ignore
|
||||
mla_vheadsize=self.kv_lora_rank,
|
||||
out=attn_output)
|
||||
# The MLA_PA path will be used as default path in the future, `_npu_paged_attention_mla` will
|
||||
# be removed after the torch_npu contains `torch_npu.atb.npu_multi_head_latent_attention` become
|
||||
# public available
|
||||
assert len(kv_c_and_k_pe_cache) > 1
|
||||
if envs.VLLM_ASCEND_MLA_PA:
|
||||
attn_output = torch_npu.atb.npu_multi_head_latent_attention(
|
||||
q_nope, q_pe, kv_c_and_k_pe_cache[0],
|
||||
kv_c_and_k_pe_cache[1], attn_metadata.decode.block_table,
|
||||
attn_metadata.decode.seq_lens, self.num_heads, self.scale,
|
||||
self.num_kv_heads)
|
||||
else:
|
||||
q = torch.cat([q_nope, q_pe], dim=-1)
|
||||
attn_output = torch.empty(
|
||||
[num_tokens, self.num_heads, self.kv_lora_rank],
|
||||
dtype=q.dtype,
|
||||
device=q.device)
|
||||
k_cache = torch.cat(
|
||||
[kv_c_and_k_pe_cache[0], kv_c_and_k_pe_cache[1]], dim=-1)
|
||||
torch_npu._npu_paged_attention_mla(
|
||||
query=q,
|
||||
key_cache=k_cache,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
num_heads=self.num_heads,
|
||||
scale_value=self.scale,
|
||||
block_table=attn_metadata.decode.
|
||||
block_table, # type:ignore
|
||||
context_lens=attn_metadata.decode.seq_lens, # type:ignore
|
||||
mla_vheadsize=self.kv_lora_rank,
|
||||
out=attn_output)
|
||||
current_ms_metadata = get_multistream_comm_context()
|
||||
if current_ms_metadata is None:
|
||||
return self._v_up_proj_and_o_proj(attn_output,
|
||||
@@ -1020,7 +1037,7 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
hidden_states_or_q_c: torch.Tensor, # query in unified attn
|
||||
hidden_states_or_kv_c_normed: torch.Tensor, # key in unified attn
|
||||
k_pe: torch.Tensor, # value in unified attn
|
||||
kv_cache: torch.Tensor,
|
||||
kv_cache: Tuple[torch.Tensor],
|
||||
attn_metadata: M,
|
||||
output: Optional[torch.Tensor] = None,
|
||||
enable_multistream_mla: bool = False,
|
||||
@@ -1151,8 +1168,12 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
prefill_q_pe.contiguous(),
|
||||
prefill_k_pe,
|
||||
max_seq_len=attn_metadata.prefill.max_seq_lens)
|
||||
|
||||
assert len(
|
||||
kv_cache
|
||||
) > 1, "the number of kv cache should be greater than 1, namely (nope_cache and rope_cache)"
|
||||
if self.torchair_graph_enabled:
|
||||
if len(kv_cache) > 0 and kv_cache[0].numel(
|
||||
if kv_cache[0].numel(
|
||||
) > 0 and attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
|
||||
slots = attn_metadata.slot_mapping
|
||||
# NOTE: Separate the kv cache in advance to avoid OOM or other issues
|
||||
@@ -1162,16 +1183,15 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
key_cache=kv_cache[0],
|
||||
value_cache=kv_cache[1],
|
||||
slot_indices=slots)
|
||||
elif kv_cache.numel() > 0:
|
||||
key = torch.cat([
|
||||
kv_c_normed.view([num_actual_toks, self.num_kv_heads, -1]),
|
||||
k_pe
|
||||
],
|
||||
dim=2)
|
||||
torch_npu._npu_reshape_and_cache_siso(
|
||||
key=key,
|
||||
key_cache=kv_cache,
|
||||
slot_indices=attn_metadata.slot_mapping.flatten())
|
||||
else:
|
||||
kv_c_normed = kv_c_normed.view(
|
||||
[num_actual_toks, self.num_kv_heads, -1])
|
||||
torch_npu._npu_reshape_and_cache(
|
||||
key=kv_c_normed,
|
||||
value=k_pe,
|
||||
key_cache=kv_cache[0],
|
||||
value_cache=kv_cache[1],
|
||||
slot_indices=attn_metadata.slot_mapping)
|
||||
if has_prefill:
|
||||
# FIX: aicore move should be also placed on the comm stream in dbo,
|
||||
# otherwise it may affect the accuracy
|
||||
|
||||
Reference in New Issue
Block a user