Disaggregate prefill for kv cache register style (#950)
### What this PR does / why we need it?
This PR adopts `LLMDataDist` for KV cache registration and a `pull_blocks`
style disaggregated prefill implementation. The interface implementation
mainly follows the design of the NIXL PR:
https://github.com/vllm-project/vllm/pull/17751/files#diff-7eaad0b7dee0626bf29d10081b0f0c5e3ea15a4af97e7b182a4e0d35f8346953
This PR can be tested with the following steps:
- Generate the rank table for all machines.
- Execute `toy_proxy.py` to launch the disaggregated prefill proxy server,
specifying the prefill IP and port and the decode IP and port.
- Run the prefill server and the decode server.
- Send the request to the disaggregated prefill proxy.
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.9.2
- vLLM main:
8d0a01a5f2
---------
Signed-off-by: ganyi <pleaplusone.gy@gmail.com>
Signed-off-by: machenglong <machenglong_yewu@cmss.chinamobile.com>
Signed-off-by: liziyu179 <3475441767@qq.com>
Signed-off-by: underfitc <hucong24@huawei.com>
Signed-off-by: zouyida2052 <zouyida@huawei.com>
Signed-off-by: liziyu <liziyu16@huawei.com>
Signed-off-by: underfituu <hzhucong@163.com>
Co-authored-by: machenglong <machenglong_yewu@cmss.chinamobile.com>
Co-authored-by: liziyu179 <3475441767@qq.com>
Co-authored-by: underfitc <hucong24@huawei.com>
Co-authored-by: zouyida2052 <zouyida@huawei.com>
Co-authored-by: liziyu <liziyu16@huawei.com>
Co-authored-by: underfituu <hzhucong@163.com>
This commit is contained in:
@@ -17,7 +17,9 @@
|
||||
# Adapted from vllm-project/vllm/vllm/worker/gpu_model_runner.py
|
||||
#
|
||||
|
||||
import copy
|
||||
import gc
|
||||
import math
|
||||
import os
|
||||
import time
|
||||
import types
|
||||
@@ -37,9 +39,12 @@ from vllm.attention import AttentionType, get_attn_backend
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.config import CompilationLevel, VllmConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
|
||||
has_kv_transfer_group)
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
|
||||
from vllm.distributed.parallel_state import (get_dp_group, get_pp_group,
|
||||
get_tp_group)
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.forward_context import get_forward_context, set_forward_context
|
||||
from vllm.inputs import INPUT_REGISTRY
|
||||
from vllm.logger import logger
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
@@ -342,6 +347,11 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
torch._logging.set_logs(
|
||||
recompiles=envs_ascend.VLLM_ASCEND_TRACE_RECOMPILES)
|
||||
|
||||
# kv role
|
||||
self.is_kv_producer = False
|
||||
if vllm_config.kv_transfer_config is not None:
|
||||
self.is_kv_producer = vllm_config.kv_transfer_config.is_kv_producer
|
||||
|
||||
def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
|
||||
"""Update the cached states and the persistent batch with the scheduler
|
||||
output.
|
||||
@@ -908,7 +918,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
) -> tuple[Union[AscendMetadata, AscendMLAMetadata,
|
||||
AscendTorchairMetadata], torch.Tensor, SpecDecodeMetadata,
|
||||
torch.Tensor, int, torch.Tensor, torch.Tensor, np.ndarray]:
|
||||
torch.Tensor, int, torch.Tensor, torch.Tensor, np.ndarray,
|
||||
Optional[set[str]], Optional[set[str]]]:
|
||||
# Check input valid
|
||||
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
||||
assert total_num_scheduled_tokens > 0
|
||||
@@ -1144,6 +1155,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.vllm_config,
|
||||
num_tokens=num_input_tokens):
|
||||
with ProfileExecuteDuration().capture_async("forward"):
|
||||
self.maybe_setup_kv_connector(scheduler_output)
|
||||
model_kwargs = {}
|
||||
if self.torchair_graph_enabled:
|
||||
model_kwargs["kv_caches"] = self.kv_caches
|
||||
@@ -1174,6 +1186,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
self.maybe_wait_for_kv_save()
|
||||
finished_sending, finished_recving = self.get_finished_kv_transfer(
|
||||
scheduler_output)
|
||||
use_spec_decode = len(
|
||||
scheduler_output.scheduled_spec_decode_tokens) > 0
|
||||
if not use_spec_decode:
|
||||
@@ -1203,7 +1218,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
|
||||
return (attn_metadata, hidden_states, spec_decode_metadata, positions,
|
||||
total_num_scheduled_tokens, logits_indices, aux_hidden_states,
|
||||
num_scheduled_tokens)
|
||||
num_scheduled_tokens, finished_sending, finished_recving)
|
||||
|
||||
def _get_cumsum_and_arange(
|
||||
self,
|
||||
@@ -1436,12 +1451,18 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
"prepare input and forward"):
|
||||
self._update_states(scheduler_output)
|
||||
if not scheduler_output.total_num_scheduled_tokens:
|
||||
# Return empty ModelRunnerOutput if there's no work to do.
|
||||
return EMPTY_MODEL_RUNNER_OUTPUT
|
||||
if not has_kv_transfer_group():
|
||||
logger.debug(
|
||||
"skip this step for we receive the data from remote disaggregate prefill node"
|
||||
)
|
||||
# Return empty ModelRunnerOutput if there's no work to do.
|
||||
return EMPTY_MODEL_RUNNER_OUTPUT
|
||||
return self.kv_connector_no_forward(scheduler_output)
|
||||
(attn_metadata, hidden_states, spec_decode_metadata, positions,
|
||||
num_scheduled_tokens, logits_indices, aux_hidden_states,
|
||||
num_scheduled_tokens_np) = (self._process_reqs(
|
||||
scheduler_output, intermediate_tensors))
|
||||
num_scheduled_tokens_np, finished_sending,
|
||||
finished_recving) = (self._process_reqs(scheduler_output,
|
||||
intermediate_tensors))
|
||||
|
||||
with ProfileExecuteDuration().capture_async("post process"):
|
||||
# Broadcast PP output for external_launcher (torchrun)
|
||||
@@ -1593,6 +1614,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
aux_hidden_states,
|
||||
)
|
||||
|
||||
if has_kv_transfer_group():
|
||||
get_kv_transfer_group().clear_connector_metadata()
|
||||
|
||||
model_runner_output = ModelRunnerOutput(
|
||||
req_ids=self.input_batch.req_ids,
|
||||
req_id_to_index=self.input_batch.req_id_to_index,
|
||||
@@ -1601,6 +1625,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
logprobs=logprobs_lists,
|
||||
prompt_logprobs_dict=prompt_logprobs_dict,
|
||||
pooler_output=[],
|
||||
finished_sending=finished_sending,
|
||||
finished_recving=finished_recving,
|
||||
)
|
||||
|
||||
durations = ProfileExecuteDuration().pop_captured_sync()
|
||||
@@ -1615,6 +1641,49 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
|
||||
return model_runner_output
|
||||
|
||||
def kv_connector_no_forward(
        self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput:
    """Service the KV connector for a step in which no model forward runs.

    Binds the connector metadata and polls for finished KV transfers
    inside a forward context created with ``None`` as its first argument;
    the model itself is not invoked here.

    Args:
        scheduler_output: Scheduler output for the current step.

    Returns:
        ``EMPTY_MODEL_RUNNER_OUTPUT`` itself when both finished sets are
        empty or ``None``; otherwise a shallow copy of it with
        ``finished_sending`` / ``finished_recving`` filled in.
    """
    with set_forward_context(None, self.vllm_config):
        self.maybe_setup_kv_connector(scheduler_output)
        sent_ids, received_ids = self.get_finished_kv_transfer(
            scheduler_output)

    # NOTE(review): the original comment here says that, when the forward is
    # skipped because remote kv is being received, one round of dummy
    # inference is necessary to prevent a hang in the collective calls.
    # No dummy run is issued in this method -- presumably it happens
    # elsewhere; confirm against the caller.
    if sent_ids or received_ids:
        # Work on a copy so the shared EMPTY_MODEL_RUNNER_OUTPUT object is
        # never mutated.
        result = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
        result.finished_sending = sent_ids
        result.finished_recving = received_ids
        return result

    return EMPTY_MODEL_RUNNER_OUTPUT
|
||||
|
||||
@staticmethod
def maybe_setup_kv_connector(scheduler_output: "SchedulerOutput"):
    """Bind this step's connector metadata and start loading KV.

    Does nothing when no KV transfer group exists. Otherwise the connector
    is checked to be a ``KVConnectorBase_V1``, the scheduler-provided
    metadata is bound to it, and ``start_load_kv`` is called with the
    current forward context.

    Args:
        scheduler_output: Scheduler output for the current step; its
            ``kv_connector_metadata`` must not be ``None`` when a KV
            transfer group exists.
    """
    if not has_kv_transfer_group():
        return

    connector = get_kv_transfer_group()
    assert isinstance(connector, KVConnectorBase_V1)

    metadata = scheduler_output.kv_connector_metadata
    assert metadata is not None
    connector.bind_connector_metadata(metadata)

    # Reads the active forward context, so a forward context must already
    # be set when this is called.
    connector.start_load_kv(get_forward_context())
|
||||
|
||||
@staticmethod
def maybe_wait_for_kv_save() -> None:
    """Call ``wait_for_save`` on the KV transfer group, if one exists."""
    if not has_kv_transfer_group():
        return
    connector = get_kv_transfer_group()
    connector.wait_for_save()
|
||||
|
||||
@staticmethod
def get_finished_kv_transfer(
    scheduler_output: "SchedulerOutput",
) -> tuple[Optional[set[str]], Optional[set[str]]]:
    """Ask the KV connector which transfers have finished.

    Args:
        scheduler_output: Scheduler output for the current step; its
            ``finished_req_ids`` are forwarded to the connector.

    Returns:
        Whatever the connector's ``get_finished`` returns -- annotated as a
        ``(finished_sending, finished_recving)`` pair -- or ``(None, None)``
        when no KV transfer group exists.
    """
    if not has_kv_transfer_group():
        return None, None

    connector = get_kv_transfer_group()
    return connector.get_finished(scheduler_output.finished_req_ids)
|
||||
|
||||
@torch.inference_mode()
|
||||
def _dummy_run(
|
||||
self,
|
||||
@@ -1633,6 +1702,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
num_scheduled_tokens = np.array(num_scheduled_tokens_list,
|
||||
dtype=np.int32)
|
||||
|
||||
# Force dummy run on prefill stage when this node is deemed as kv producer.
|
||||
if self.is_kv_producer:
|
||||
with_prefill = True
|
||||
|
||||
with self.maybe_dummy_run_with_lora(self.lora_config,
|
||||
num_scheduled_tokens):
|
||||
model = self.model
|
||||
@@ -1899,9 +1972,15 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.kv_cache_config = kv_cache_config
|
||||
import torch_npu
|
||||
acl_format = ACL_FORMAT_FRACTAL_NZ if is_310p(
|
||||
) else ACL_FORMAT_FRACTAL_ND
|
||||
) and not self.torchair_graph_enabled else ACL_FORMAT_FRACTAL_ND
|
||||
kv_caches: Dict[str, torch.Tensor] = {}
|
||||
|
||||
def align_memory(tensor: torch.Tensor, alignment: int) -> torch.Tensor:
    """Return a tail slice of ``tensor`` starting at an aligned address.

    Args:
        tensor: Tensor to slice along its first dimension.
        alignment: Required alignment of the start address, in bytes.

    Returns:
        ``tensor[k:]``, where ``k`` is the number of leading elements
        skipped to reach the next multiple of ``alignment`` at or above
        ``tensor.data_ptr()``.
    """
    start_addr = tensor.data_ptr()
    # Round the start address up to the next multiple of `alignment`.
    n_units = (start_addr + alignment - 1) // alignment
    target_addr = n_units * alignment
    # Convert the byte gap into a count of whole elements to skip.
    gap_bytes = target_addr - start_addr
    skip = gap_bytes // tensor.element_size()
    return tensor[int(skip):]
|
||||
|
||||
self.input_batch = InputBatch(
|
||||
max_num_reqs=self.max_num_reqs,
|
||||
max_model_len=self.model_config.max_model_len,
|
||||
@@ -1935,6 +2014,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
# different GPUs, and `kv_cache_config.num_blocks` is set to
|
||||
# the min of all `num_blocks`. Verify it here.
|
||||
assert num_blocks >= kv_cache_config.num_blocks
|
||||
alignment = 2 * 1024 * 1024
|
||||
# TODO: remove this after the OOM issue is located and fixed, otherwise, some model may
|
||||
# encounter OOM issue
|
||||
if isinstance(kv_cache_spec, FullAttentionSpec):
|
||||
@@ -1949,58 +2029,78 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
num_blocks, kv_cache_spec.block_size,
|
||||
kv_cache_spec.num_kv_heads,
|
||||
kv_cache_spec.head_size)
|
||||
if self.torchair_graph_enabled:
|
||||
if len(kv_cache_shape) == 3:
|
||||
# for non MLA attention backend that use torchair, we consider to pass kv_cache layout
|
||||
# of BSH ([num_blocks, block_size, kv_head_dim * head_size]) to attention.
|
||||
dtype = kv_cache_spec.dtype
|
||||
if self.model_config.is_deepseek_mla:
|
||||
|
||||
kv_caches[layer_name] = (
|
||||
torch.zeros(kv_cache_shape,
|
||||
dtype=self.kv_cache_dtype,
|
||||
device=self.device),
|
||||
torch.zeros(kv_cache_shape,
|
||||
dtype=self.kv_cache_dtype,
|
||||
device=self.device))
|
||||
# atb reshape_and_cache does not support torchair.
|
||||
kv_caches[layer_name] = (
|
||||
torch_npu.npu_format_cast(
|
||||
kv_caches[layer_name][0],
|
||||
ACL_FORMAT_FRACTAL_ND),
|
||||
torch_npu.npu_format_cast(
|
||||
kv_caches[layer_name][1],
|
||||
ACL_FORMAT_FRACTAL_ND),
|
||||
)
|
||||
num_blocks, block_size, num_kv_heads, head_size = kv_cache_shape
|
||||
rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
|
||||
nope_dim = head_size - rope_dim
|
||||
nope_cache_shape = (num_blocks, block_size,
|
||||
num_kv_heads, nope_dim)
|
||||
rope_cache_shape = (num_blocks, block_size,
|
||||
num_kv_heads, rope_dim)
|
||||
if self.vllm_config.kv_transfer_config is None:
|
||||
# For no disaggregate pd scenario, allocate kv cache in normal way
|
||||
rope_cache = torch.zeros(rope_cache_shape,
|
||||
dtype=dtype,
|
||||
device=self.device)
|
||||
nope_cache = torch.zeros(nope_cache_shape,
|
||||
dtype=dtype,
|
||||
device=self.device)
|
||||
rope_cache = torch_npu.npu_format_cast(
|
||||
rope_cache, acl_format)
|
||||
nope_cache = torch_npu.npu_format_cast(
|
||||
nope_cache, acl_format)
|
||||
else:
|
||||
# for MLA attention backend that use torchair.
|
||||
layer_kv_cache_nope = torch.zeros(
|
||||
kv_cache_shape[:-1] +
|
||||
(self.model_config.hf_text_config.kv_lora_rank,
|
||||
),
|
||||
dtype=self.dtype,
|
||||
pin_memory=True,
|
||||
|
||||
# In order to transfer kv cache through the register_memory api from llmdatadist, the memory
|
||||
# address should be aligned by 2M. In most case, torch_npu can allocate 2M aligned memory, but
|
||||
# we found there are also some exceptions during test, so we manually align that memory here; this part
|
||||
# of code may consume 2M * 2 * elem_size memory every layer.
|
||||
nope_allocate_shape = num_blocks * block_size * num_kv_heads * nope_dim
|
||||
nope_allocate_shape_alignment = nope_allocate_shape + alignment
|
||||
rope_allocate_shape = num_blocks * block_size * num_kv_heads * rope_dim
|
||||
rope_allocate_shape_alignment = rope_allocate_shape + alignment
|
||||
|
||||
nope_cache = torch.zeros(
|
||||
nope_allocate_shape_alignment,
|
||||
dtype=dtype,
|
||||
device=self.device)
|
||||
layer_kv_cache_pe = torch.zeros(
|
||||
kv_cache_shape[:-1] +
|
||||
(self.model_config.hf_text_config.
|
||||
qk_rope_head_dim, ),
|
||||
dtype=self.dtype,
|
||||
pin_memory=True,
|
||||
rope_cache = torch.zeros(
|
||||
rope_allocate_shape_alignment,
|
||||
dtype=dtype,
|
||||
device=self.device)
|
||||
kv_caches[layer_name] = (layer_kv_cache_nope,
|
||||
layer_kv_cache_pe)
|
||||
kv_caches[layer_name] = (
|
||||
torch_npu.npu_format_cast(
|
||||
kv_caches[layer_name][0], acl_format),
|
||||
torch_npu.npu_format_cast(
|
||||
kv_caches[layer_name][1], acl_format),
|
||||
)
|
||||
nope_cache = align_memory(
|
||||
nope_cache,
|
||||
alignment)[:nope_allocate_shape].view(
|
||||
nope_cache_shape)
|
||||
rope_cache = align_memory(
|
||||
rope_cache,
|
||||
alignment)[:rope_allocate_shape].view(
|
||||
rope_cache_shape)
|
||||
kv_caches[layer_name] = (nope_cache, rope_cache)
|
||||
else:
|
||||
kv_caches[layer_name] = torch.zeros(
|
||||
kv_cache_shape,
|
||||
dtype=self.kv_cache_dtype,
|
||||
device=self.device)
|
||||
kv_caches[layer_name] = \
|
||||
torch_npu.npu_format_cast(kv_caches[layer_name], acl_format)
|
||||
num_caches = kv_cache_shape[0]
|
||||
kv_cache_list = []
|
||||
for i in range(num_caches):
|
||||
cache_shape = kv_cache_shape[1:]
|
||||
if self.vllm_config.kv_transfer_config is None:
|
||||
kv_cache = torch.zeros(cache_shape,
|
||||
dtype=dtype,
|
||||
device=self.device)
|
||||
kv_cache = torch_npu.npu_format_cast(
|
||||
kv_cache, acl_format)
|
||||
else:
|
||||
cache_size = math.prod(cache_shape)
|
||||
cache_size_aligned = cache_size + alignment
|
||||
kv_cache = torch.zeros(cache_size_aligned,
|
||||
dtype=dtype,
|
||||
device=self.device)
|
||||
kv_cache = align_memory(
|
||||
kv_cache,
|
||||
alignment)[:cache_size].view(cache_shape)
|
||||
kv_cache_list.append(kv_cache)
|
||||
kv_caches[layer_name] = tuple(kv_cache_list)
|
||||
else:
|
||||
# TODO: add new branches when introducing more types of
|
||||
# KV cache specs.
|
||||
@@ -2011,6 +2111,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.vllm_config.compilation_config.static_forward_context,
|
||||
self.kv_caches)
|
||||
|
||||
if has_kv_transfer_group():
|
||||
get_kv_transfer_group().register_kv_caches(kv_caches)
|
||||
|
||||
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
|
||||
"""
|
||||
Generates the KVCacheSpec by parsing the kv cache format from each
|
||||
|
||||
Reference in New Issue
Block a user