update
This commit is contained in:
134
vllm/v1/worker/gpu/kv_connector.py
Normal file
134
vllm/v1/worker/gpu/kv_connector.py
Normal file
@@ -0,0 +1,134 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import copy
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_transfer import (
|
||||
get_kv_transfer_group,
|
||||
has_kv_transfer_group,
|
||||
kv_transfer_state,
|
||||
)
|
||||
from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks
|
||||
from vllm.forward_context import (
|
||||
get_forward_context,
|
||||
is_forward_context_available,
|
||||
set_forward_context,
|
||||
)
|
||||
from vllm.v1.outputs import (
|
||||
EMPTY_MODEL_RUNNER_OUTPUT,
|
||||
KVConnectorOutput,
|
||||
ModelRunnerOutput,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
|
||||
|
||||
class KVConnector:
    """KVConnector interface used by GPUModelRunner.

    This base class is a stateless no-op: every hook does nothing and
    `no_forward` reports an empty model-runner output. It is used directly
    (see the module-level `NO_OP_KV_CONNECTOR` singleton) when no KV transfer
    group is configured, so the model runner can call the hooks
    unconditionally.
    """

    def pre_forward(self, scheduler_output: "SchedulerOutput") -> None:
        """Hook invoked before the model forward pass. No-op here."""
        pass

    def post_forward(
        self, scheduler_output: "SchedulerOutput", wait_for_save: bool = True
    ) -> KVConnectorOutput | None:
        """Hook invoked after the model forward pass.

        Returns None, meaning there is no KV-connector output to attach to
        the model runner output.
        """
        return None

    def no_forward(self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput:
        """Handle a scheduler step in which no model forward pass is run.

        With no connector configured there is nothing to load or save, so
        the shared empty model-runner output is returned as-is.
        """
        return EMPTY_MODEL_RUNNER_OUTPUT

    def set_disabled(self, disabled: bool) -> None:
        """Enable/disable the connector. No-op for the no-op connector."""
        pass
|
||||
|
||||
|
||||
class ActiveKVConnector(KVConnector):
    """KVConnector backed by the process-wide KV transfer group.

    Wraps the connector agent returned by `get_kv_transfer_group()` and
    drives its load/save hooks around each model forward pass. All hooks
    become no-ops while `self._disabled` is True (see `set_disabled`).
    """

    def __init__(
        self, vllm_config: VllmConfig, kv_caches_dict: dict[str, torch.Tensor]
    ):
        # kv_caches_dict presumably maps layer names to their KV-cache
        # tensors — TODO(review): confirm key semantics against the caller.
        self.vllm_config = vllm_config
        self.kv_connector = get_kv_transfer_group()
        # Register kv caches with KV Connector if applicable.
        # TODO: support cross_layers_kv_cache
        # (see https://github.com/vllm-project/vllm/pull/27743)
        self.kv_connector.register_kv_caches(kv_caches_dict)
        self.kv_connector.set_host_xfer_buffer_ops(copy_kv_blocks)

        # When True, every hook on this object returns early.
        self._disabled = False

    def pre_forward(self, scheduler_output: "SchedulerOutput") -> None:
        """Bind this step's connector metadata and kick off async KV loads.

        Must run before the model forward pass so that loaded KV blocks are
        in place (or in flight) when attention executes.
        """
        if self._disabled:
            return

        # Let the connector release/clean up state for preempted requests
        # before binding the new step's metadata.
        if scheduler_output.preempted_req_ids:
            self.kv_connector.handle_preemptions(scheduler_output.preempted_req_ids)
        kv_connector_metadata = scheduler_output.kv_connector_metadata
        # The scheduler always attaches metadata when a transfer group is
        # active; a None here indicates a scheduler/worker mismatch.
        assert kv_connector_metadata is not None
        self.kv_connector.bind_connector_metadata(kv_connector_metadata)

        # TODO: sort out KV Connectors' use of forward_context
        # start_load_kv requires a live forward context; create a dummy one
        # only if the caller has not already established it.
        if is_forward_context_available():
            self.kv_connector.start_load_kv(get_forward_context())
        else:
            with set_forward_context(None, self.vllm_config):
                self.kv_connector.start_load_kv(get_forward_context())

    def post_forward(
        self,
        scheduler_output: "SchedulerOutput",
        wait_for_save: bool = True,
        clear_metadata: bool = True,
    ) -> KVConnectorOutput | None:
        """Collect connector results after the forward pass.

        Args:
            scheduler_output: The step's scheduler output; its
                finished_req_ids are passed through to the connector.
            wait_for_save: If True, block until pending KV saves complete
                before gathering results.
            clear_metadata: If True, drop the per-step connector metadata
                bound in `pre_forward`. Pass False when another pass (e.g.
                a draft model run) still needs it; see `clear_metadata()`.

        Returns:
            A populated KVConnectorOutput, or None when disabled.
        """
        if self._disabled:
            return None

        output = KVConnectorOutput()
        if wait_for_save:
            self.kv_connector.wait_for_save()
        output.finished_sending, output.finished_recving = (
            self.kv_connector.get_finished(scheduler_output.finished_req_ids)
        )
        output.invalid_block_ids = self.kv_connector.get_block_ids_with_load_errors()
        output.kv_connector_stats = self.kv_connector.get_kv_connector_stats()
        output.kv_cache_events = self.kv_connector.get_kv_connector_kv_cache_events()
        if clear_metadata:
            self.kv_connector.clear_connector_metadata()
        return output

    def clear_metadata(self) -> None:
        """Clear the connector metadata. Call this after draft model runs."""
        if not self._disabled:
            self.kv_connector.clear_connector_metadata()

    def no_forward(self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput:
        """Run the connector hooks for a step with no model forward pass.

        KV loads/sends may still need to progress even when the model is
        not executed, so pre/post hooks run back-to-back (without waiting
        for saves). Returns the shared empty output when the connector
        produced nothing, otherwise a shallow copy carrying the connector
        output.
        """
        if self._disabled:
            return EMPTY_MODEL_RUNNER_OUTPUT

        self.pre_forward(scheduler_output)
        kv_connector_output = self.post_forward(scheduler_output, wait_for_save=False)
        if kv_connector_output is None or kv_connector_output.is_empty():
            return EMPTY_MODEL_RUNNER_OUTPUT
        # Shallow-copy the shared empty output so the singleton is never
        # mutated; only the kv_connector_output field is overridden.
        output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
        output.kv_connector_output = kv_connector_output
        return output

    def set_disabled(self, disabled: bool) -> None:
        """Toggle the connector on/off.

        Besides guarding this object's own hooks via `_disabled`, this also
        detaches/reattaches the global connector agent so that layer-wise
        hooks elsewhere in the model see no connector while disabled.
        """
        # Ensure that layer-wise connector hooks aren't called when disabled.
        kv_transfer_state._KV_CONNECTOR_AGENT = None if disabled else self.kv_connector
        self._disabled = disabled
|
||||
|
||||
|
||||
NO_OP_KV_CONNECTOR = KVConnector()
|
||||
|
||||
|
||||
def get_kv_connector(
    vllm_config: VllmConfig, kv_caches_dict: dict[str, torch.Tensor]
) -> KVConnector:
    """Build the KVConnector for this worker.

    Returns an ActiveKVConnector bound to the process-wide KV transfer
    group when one is configured; otherwise the shared no-op connector,
    so callers may invoke the hooks unconditionally.
    """
    if has_kv_transfer_group():
        return ActiveKVConnector(vllm_config, kv_caches_dict)
    # No transfer group configured: fall back to the stateless no-op.
    return NO_OP_KV_CONNECTOR
|
||||
Reference in New Issue
Block a user