Disaggregate prefill for kv cache register style (#950)

### What this PR does / why we need it?
This PR adopts `LLMDataDist` for KV cache registration and a `pull_blocks`-style
disaggregated prefill implementation. The interface implementation
mainly follows the design of the NIXL PR
https://github.com/vllm-project/vllm/pull/17751/files#diff-7eaad0b7dee0626bf29d10081b0f0c5e3ea15a4af97e7b182a4e0d35f8346953
.

This PR can be tested with the following steps:
- Generate the rank table for all machines.
- Execute `toy_proxy.py` to launch the disaggregated prefill proxy server,
specifying the prefill IP and port and the decode IP and port.
- Run the prefill server and the decode server.
- Send the request to the disaggregated prefill proxy.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?


- vLLM version: v0.9.2
- vLLM main:
8d0a01a5f2

---------

Signed-off-by: ganyi <pleaplusone.gy@gmail.com>
Signed-off-by: machenglong <machenglong_yewu@cmss.chinamobile.com>
Signed-off-by: liziyu179 <3475441767@qq.com>
Signed-off-by: underfitc <hucong24@huawei.com>
Signed-off-by: zouyida2052 <zouyida@huawei.com>
Signed-off-by: liziyu <liziyu16@huawei.com>
Signed-off-by: underfituu <hzhucong@163.com>
Co-authored-by: machenglong <machenglong_yewu@cmss.chinamobile.com>
Co-authored-by: liziyu179 <3475441767@qq.com>
Co-authored-by: underfitc <hucong24@huawei.com>
Co-authored-by: zouyida2052 <zouyida@huawei.com>
Co-authored-by: liziyu <liziyu16@huawei.com>
Co-authored-by: underfituu <hzhucong@163.com>
This commit is contained in:
Pleaplusone
2025-07-26 17:15:47 +08:00
committed by GitHub
parent 17a430f7b8
commit df0ec55162
28 changed files with 2833 additions and 144 deletions

View File

@@ -17,7 +17,9 @@
# Adapted from vllm-project/vllm/vllm/worker/gpu_model_runner.py
#
import copy
import gc
import math
import os
import time
import types
@@ -37,9 +39,12 @@ from vllm.attention import AttentionType, get_attn_backend
from vllm.attention.layer import Attention
from vllm.config import CompilationLevel, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
has_kv_transfer_group)
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
from vllm.distributed.parallel_state import (get_dp_group, get_pp_group,
get_tp_group)
from vllm.forward_context import set_forward_context
from vllm.forward_context import get_forward_context, set_forward_context
from vllm.inputs import INPUT_REGISTRY
from vllm.logger import logger
from vllm.model_executor.layers.fused_moe import FusedMoE
@@ -342,6 +347,11 @@ class NPUModelRunner(LoRAModelRunnerMixin):
torch._logging.set_logs(
recompiles=envs_ascend.VLLM_ASCEND_TRACE_RECOMPILES)
# kv role
self.is_kv_producer = False
if vllm_config.kv_transfer_config is not None:
self.is_kv_producer = vllm_config.kv_transfer_config.is_kv_producer
def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
"""Update the cached states and the persistent batch with the scheduler
output.
@@ -908,7 +918,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> tuple[Union[AscendMetadata, AscendMLAMetadata,
AscendTorchairMetadata], torch.Tensor, SpecDecodeMetadata,
torch.Tensor, int, torch.Tensor, torch.Tensor, np.ndarray]:
torch.Tensor, int, torch.Tensor, torch.Tensor, np.ndarray,
Optional[set[str]], Optional[set[str]]]:
# Check input valid
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
assert total_num_scheduled_tokens > 0
@@ -1144,6 +1155,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.vllm_config,
num_tokens=num_input_tokens):
with ProfileExecuteDuration().capture_async("forward"):
self.maybe_setup_kv_connector(scheduler_output)
model_kwargs = {}
if self.torchair_graph_enabled:
model_kwargs["kv_caches"] = self.kv_caches
@@ -1174,6 +1186,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
**model_kwargs,
)
self.maybe_wait_for_kv_save()
finished_sending, finished_recving = self.get_finished_kv_transfer(
scheduler_output)
use_spec_decode = len(
scheduler_output.scheduled_spec_decode_tokens) > 0
if not use_spec_decode:
@@ -1203,7 +1218,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
return (attn_metadata, hidden_states, spec_decode_metadata, positions,
total_num_scheduled_tokens, logits_indices, aux_hidden_states,
num_scheduled_tokens)
num_scheduled_tokens, finished_sending, finished_recving)
def _get_cumsum_and_arange(
self,
@@ -1436,12 +1451,18 @@ class NPUModelRunner(LoRAModelRunnerMixin):
"prepare input and forward"):
self._update_states(scheduler_output)
if not scheduler_output.total_num_scheduled_tokens:
# Return empty ModelRunnerOuptut if there's no work to do.
return EMPTY_MODEL_RUNNER_OUTPUT
if not has_kv_transfer_group():
logger.debug(
"skip this step for we receive the data from remote disaggregate prefill node"
)
# Return empty ModelRunnerOuptut if there's no work to do.
return EMPTY_MODEL_RUNNER_OUTPUT
return self.kv_connector_no_forward(scheduler_output)
(attn_metadata, hidden_states, spec_decode_metadata, positions,
num_scheduled_tokens, logits_indices, aux_hidden_states,
num_scheduled_tokens_np) = (self._process_reqs(
scheduler_output, intermediate_tensors))
num_scheduled_tokens_np, finished_sending,
finished_recving) = (self._process_reqs(scheduler_output,
intermediate_tensors))
with ProfileExecuteDuration().capture_async("post process"):
# Broadcast PP output for external_launcher (torchrun)
@@ -1593,6 +1614,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
aux_hidden_states,
)
if has_kv_transfer_group():
get_kv_transfer_group().clear_connector_metadata()
model_runner_output = ModelRunnerOutput(
req_ids=self.input_batch.req_ids,
req_id_to_index=self.input_batch.req_id_to_index,
@@ -1601,6 +1625,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
logprobs=logprobs_lists,
prompt_logprobs_dict=prompt_logprobs_dict,
pooler_output=[],
finished_sending=finished_sending,
finished_recving=finished_recving,
)
durations = ProfileExecuteDuration().pop_captured_sync()
@@ -1615,6 +1641,49 @@ class NPUModelRunner(LoRAModelRunnerMixin):
return model_runner_output
def kv_connector_no_forward(
        self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput:
    """Drive the KV connector for a step that runs no model forward pass.

    The connector is set up via ``maybe_setup_kv_connector`` inside a
    forward context created with ``None`` as its first argument, and the
    finished send/receive request ids are then collected.

    Args:
        scheduler_output: Scheduler output for the current step.

    Returns:
        The shared ``EMPTY_MODEL_RUNNER_OUTPUT`` when neither finished set
        is truthy; otherwise a shallow copy of it carrying
        ``finished_sending`` and ``finished_recving``.
    """
    with set_forward_context(None, self.vllm_config):
        self.maybe_setup_kv_connector(scheduler_output)
        sent_ids, received_ids = self.get_finished_kv_transfer(
            scheduler_output)

    # NOTE (from the original author): when no forward runs because remote
    # KV is being received, one round of dummy inference is necessary to
    # prevent a hang in the collective calls. This method does not run one.
    if sent_ids or received_ids:
        # Copy first so the shared empty-output constant is never mutated.
        result = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
        result.finished_sending = sent_ids
        result.finished_recving = received_ids
        return result
    return EMPTY_MODEL_RUNNER_OUTPUT
@staticmethod
def maybe_setup_kv_connector(scheduler_output: "SchedulerOutput"):
    """Bind this step's connector metadata and start loading KV, if enabled.

    Does nothing when no KV transfer group exists. Otherwise the connector
    must be a ``KVConnectorBase_V1`` and the scheduler output must carry
    connector metadata; both are enforced with assertions.

    Args:
        scheduler_output: Scheduler output whose ``kv_connector_metadata``
            is bound to the connector.
    """
    if not has_kv_transfer_group():
        return

    connector = get_kv_transfer_group()
    assert isinstance(connector, KVConnectorBase_V1)

    metadata = scheduler_output.kv_connector_metadata
    assert metadata is not None
    connector.bind_connector_metadata(metadata)

    # Loading is started against the forward context that is currently set.
    connector.start_load_kv(get_forward_context())
@staticmethod
def maybe_wait_for_kv_save() -> None:
    """Call the KV connector's ``wait_for_save`` when a transfer group exists."""
    if not has_kv_transfer_group():
        return
    get_kv_transfer_group().wait_for_save()
@staticmethod
def get_finished_kv_transfer(
    scheduler_output: "SchedulerOutput",
) -> tuple[Optional[set[str]], Optional[set[str]]]:
    """Ask the KV connector which requests have finished transferring.

    Args:
        scheduler_output: Scheduler output whose ``finished_req_ids`` is
            forwarded to the connector's ``get_finished``.

    Returns:
        Whatever the connector's ``get_finished`` returns (annotated as a
        ``(finished_sending, finished_recving)`` pair), or ``(None, None)``
        when no KV transfer group exists.
    """
    if not has_kv_transfer_group():
        return None, None
    connector = get_kv_transfer_group()
    return connector.get_finished(scheduler_output.finished_req_ids)
@torch.inference_mode()
def _dummy_run(
self,
@@ -1633,6 +1702,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
num_scheduled_tokens = np.array(num_scheduled_tokens_list,
dtype=np.int32)
# Force dummy run on prefill stage when this node is deemed as kv producer.
if self.is_kv_producer:
with_prefill = True
with self.maybe_dummy_run_with_lora(self.lora_config,
num_scheduled_tokens):
model = self.model
@@ -1899,9 +1972,15 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.kv_cache_config = kv_cache_config
import torch_npu
acl_format = ACL_FORMAT_FRACTAL_NZ if is_310p(
) else ACL_FORMAT_FRACTAL_ND
) and not self.torchair_graph_enabled else ACL_FORMAT_FRACTAL_ND
kv_caches: Dict[str, torch.Tensor] = {}
def align_memory(tensor: torch.Tensor, alignment: int) -> torch.Tensor:
    """Return a view of ``tensor`` that starts at an ``alignment``-byte boundary.

    Leading elements are skipped until the data pointer is a multiple of
    ``alignment``. The slice is taken along the first dimension with the
    offset counted in elements, so callers are expected to pass a flat
    buffer over-allocated by enough elements to absorb the skipped prefix.

    Args:
        tensor: Buffer to align.
        alignment: Required address alignment in bytes; must be positive.

    Returns:
        ``tensor[k:]`` where ``k`` is the smallest element count that makes
        the data pointer a multiple of ``alignment`` (``k == 0`` if the
        tensor is already aligned).

    Raises:
        ValueError: If ``alignment`` is not positive, or if the padding
            needed is not a whole number of elements (no element boundary
            of this tensor can then satisfy the alignment).
    """
    if alignment <= 0:
        raise ValueError(
            f"alignment must be a positive number of bytes, got {alignment}")

    # Bytes between the current address and the next aligned address
    # (0 when the tensor is already aligned).
    padding_bytes = -tensor.data_ptr() % alignment

    # Floor division alone would silently round a fractional element count
    # down and return a misaligned view, so reject that case explicitly.
    skip, leftover = divmod(padding_bytes, tensor.element_size())
    if leftover:
        raise ValueError(
            f"cannot align tensor to {alignment} bytes: required padding of "
            f"{padding_bytes} bytes is not a multiple of the element size "
            f"({tensor.element_size()} bytes)")
    return tensor[skip:]
self.input_batch = InputBatch(
max_num_reqs=self.max_num_reqs,
max_model_len=self.model_config.max_model_len,
@@ -1935,6 +2014,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
# different GPUs, and `kv_cache_config.num_blocks` is set to
# the min of all `num_blocks`. Verify it here.
assert num_blocks >= kv_cache_config.num_blocks
alignment = 2 * 1024 * 1024
# TODO: remove this after the OOM issue is located and fixed, otherwise, some model may
# encounter OOM issue
if isinstance(kv_cache_spec, FullAttentionSpec):
@@ -1949,58 +2029,78 @@ class NPUModelRunner(LoRAModelRunnerMixin):
num_blocks, kv_cache_spec.block_size,
kv_cache_spec.num_kv_heads,
kv_cache_spec.head_size)
if self.torchair_graph_enabled:
if len(kv_cache_shape) == 3:
# for non MLA attention backend that use torchair, we consider to pass kv_cache layout
# of BSH ([num_blocks, block_size, kv_head_dim * head_size]) to attention.
dtype = kv_cache_spec.dtype
if self.model_config.is_deepseek_mla:
kv_caches[layer_name] = (
torch.zeros(kv_cache_shape,
dtype=self.kv_cache_dtype,
device=self.device),
torch.zeros(kv_cache_shape,
dtype=self.kv_cache_dtype,
device=self.device))
# atb reshape_and_cache does not support torchair.
kv_caches[layer_name] = (
torch_npu.npu_format_cast(
kv_caches[layer_name][0],
ACL_FORMAT_FRACTAL_ND),
torch_npu.npu_format_cast(
kv_caches[layer_name][1],
ACL_FORMAT_FRACTAL_ND),
)
num_blocks, block_size, num_kv_heads, head_size = kv_cache_shape
rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
nope_dim = head_size - rope_dim
nope_cache_shape = (num_blocks, block_size,
num_kv_heads, nope_dim)
rope_cache_shape = (num_blocks, block_size,
num_kv_heads, rope_dim)
if self.vllm_config.kv_transfer_config is None:
# For no disaggregate pd scenario, allocate kv cache in normal way
rope_cache = torch.zeros(rope_cache_shape,
dtype=dtype,
device=self.device)
nope_cache = torch.zeros(nope_cache_shape,
dtype=dtype,
device=self.device)
rope_cache = torch_npu.npu_format_cast(
rope_cache, acl_format)
nope_cache = torch_npu.npu_format_cast(
nope_cache, acl_format)
else:
# for MLA attention backend that use torchair.
layer_kv_cache_nope = torch.zeros(
kv_cache_shape[:-1] +
(self.model_config.hf_text_config.kv_lora_rank,
),
dtype=self.dtype,
pin_memory=True,
# In order to transfer kv cache through the register_memory api from llmdatadist, the memory
# address should be aligned to 2M. In most cases torch_npu allocates 2M-aligned memory, but
# we found some exceptions during testing, so we manually align the memory here. This part
# of the code may consume an extra 2M * 2 * elem_size of memory per layer.
nope_allocate_shape = num_blocks * block_size * num_kv_heads * nope_dim
nope_allocate_shape_alignment = nope_allocate_shape + alignment
rope_allocate_shape = num_blocks * block_size * num_kv_heads * rope_dim
rope_allocate_shape_alignment = rope_allocate_shape + alignment
nope_cache = torch.zeros(
nope_allocate_shape_alignment,
dtype=dtype,
device=self.device)
layer_kv_cache_pe = torch.zeros(
kv_cache_shape[:-1] +
(self.model_config.hf_text_config.
qk_rope_head_dim, ),
dtype=self.dtype,
pin_memory=True,
rope_cache = torch.zeros(
rope_allocate_shape_alignment,
dtype=dtype,
device=self.device)
kv_caches[layer_name] = (layer_kv_cache_nope,
layer_kv_cache_pe)
kv_caches[layer_name] = (
torch_npu.npu_format_cast(
kv_caches[layer_name][0], acl_format),
torch_npu.npu_format_cast(
kv_caches[layer_name][1], acl_format),
)
nope_cache = align_memory(
nope_cache,
alignment)[:nope_allocate_shape].view(
nope_cache_shape)
rope_cache = align_memory(
rope_cache,
alignment)[:rope_allocate_shape].view(
rope_cache_shape)
kv_caches[layer_name] = (nope_cache, rope_cache)
else:
kv_caches[layer_name] = torch.zeros(
kv_cache_shape,
dtype=self.kv_cache_dtype,
device=self.device)
kv_caches[layer_name] = \
torch_npu.npu_format_cast(kv_caches[layer_name], acl_format)
num_caches = kv_cache_shape[0]
kv_cache_list = []
for i in range(num_caches):
cache_shape = kv_cache_shape[1:]
if self.vllm_config.kv_transfer_config is None:
kv_cache = torch.zeros(cache_shape,
dtype=dtype,
device=self.device)
kv_cache = torch_npu.npu_format_cast(
kv_cache, acl_format)
else:
cache_size = math.prod(cache_shape)
cache_size_aligned = cache_size + alignment
kv_cache = torch.zeros(cache_size_aligned,
dtype=dtype,
device=self.device)
kv_cache = align_memory(
kv_cache,
alignment)[:cache_size].view(cache_shape)
kv_cache_list.append(kv_cache)
kv_caches[layer_name] = tuple(kv_cache_list)
else:
# TODO: add new branches when introducing more types of
# KV cache specs.
@@ -2011,6 +2111,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.vllm_config.compilation_config.static_forward_context,
self.kv_caches)
if has_kv_transfer_group():
get_kv_transfer_group().register_kv_caches(kv_caches)
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
"""
Generates the KVCacheSpec by parsing the kv cache format from each