Disaggregate prefill for kv cache register style (#950)

### What this PR does / why we need it?
This PR adopt `LLMDataDist` for kv cache register and `pull_blocks`
style disaggregate prefill implementation. The interface implementation
mainly follows the design of NIXL PR
https://github.com/vllm-project/vllm/pull/17751/files#diff-7eaad0b7dee0626bf29d10081b0f0c5e3ea15a4af97e7b182a4e0d35f8346953
.

This PR can be test with the following step:
- Generate the rank table for all machine.
- execute`toy_proxy.py` to launch the disaggregate prefill proxy server,
specify the prefill ip, port and the decode ip, port
- Run the prefill server and decode server.
- send the request to the disaggregate prefill proxy

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?


- vLLM version: v0.9.2
- vLLM main:
8d0a01a5f2

---------

Signed-off-by: ganyi <pleaplusone.gy@gmail.com>
Signed-off-by: machenglong <machenglong_yewu@cmss.chinamobile.com>
Signed-off-by: liziyu179 <3475441767@qq.com>
Signed-off-by: underfitc <hucong24@huawei.com>
Signed-off-by: zouyida2052 <zouyida@huawei.com>
Signed-off-by: liziyu <liziyu16@huawei.com>
Signed-off-by: underfituu <hzhucong@163.com>
Co-authored-by: machenglong <machenglong_yewu@cmss.chinamobile.com>
Co-authored-by: liziyu179 <3475441767@qq.com>
Co-authored-by: underfitc <hucong24@huawei.com>
Co-authored-by: zouyida2052 <zouyida@huawei.com>
Co-authored-by: liziyu <liziyu16@huawei.com>
Co-authored-by: underfituu <hzhucong@163.com>
This commit is contained in:
Pleaplusone
2025-07-26 17:15:47 +08:00
committed by GitHub
parent 17a430f7b8
commit df0ec55162
28 changed files with 2833 additions and 144 deletions

View File

@@ -252,7 +252,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
kv_cache: Tuple[torch.Tensor],
attn_metadata: AscendMetadata,
output: Optional[torch.Tensor] = None,
trace_flag: bool = True,
@@ -262,8 +262,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
query: shape = [batch_size, seq_len, num_heads * head_size]
key: shape = [batch_size, seq_len, num_kv_heads * head_size]
value: shape = [batch_size, seq_len, num_kv_heads * head_size]
kv_cache: shape = [2, num_blocks, block_size,
num_kv_heads, head_size]
kv_cache: shape = [key_cache, value_cache]
key_cache = [num_blocks, block_size,
num_kv_heads, head_size]
value_cache = [num_blocks, block_size,
@@ -273,8 +272,8 @@ class AscendAttentionBackendImpl(AttentionImpl):
shape = [batch_size * seq_len, num_heads, head_size]
"""
num_tokens = query.shape[0]
use_kv_cache_int8 = kv_cache.numel(
) > 0 and kv_cache[0].dtype == torch.int8
use_kv_cache_int8 = len(
kv_cache) > 0 and kv_cache[0].dtype == torch.int8
if output is None:
output = torch.empty(num_tokens,
self.num_heads,
@@ -314,7 +313,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
# TODO: Remove this contiguous in the future.
value = value.contiguous()
if kv_cache.numel() > 0:
if len(kv_cache) > 1:
if self.key_cache is None:
self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
slots = attn_metadata.slot_mapping