[gpt-oss] Add gpt-oss bf16 support
This commit is contained in:
6
vllm/distributed/kv_transfer/kv_connector/v1/__init__.py
Normal file
6
vllm/distributed/kv_transfer/kv_connector/v1/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
KVConnectorBase_V1, KVConnectorRole)
|
||||
|
||||
__all__ = ["KVConnectorRole", "KVConnectorBase_V1"]
|
||||
283
vllm/distributed/kv_transfer/kv_connector/v1/base.py
Normal file
283
vllm/distributed/kv_transfer/kv_connector/v1/base.py
Normal file
@@ -0,0 +1,283 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
KVConnectorBase_V1 Class for Distributed KV Cache & Hidden State
|
||||
communication in vLLM v1
|
||||
|
||||
The class provides the following primitives:
|
||||
Scheduler-side: runs in the scheduler, binds metadata, which
|
||||
is used by the worker-side to load/save KV cache.
|
||||
get_num_new_matched_tokens() - get number of new tokens
|
||||
that exist in the remote KV cache. Might be called multiple
|
||||
times for a given request and should be side-effect free.
|
||||
update_state_after_alloc() - update KVConnector state after
|
||||
temporary buffer alloc by the CacheManager.
|
||||
request_finished() - called when a request is finished, with
|
||||
the computed kv cache blocks for the request.
|
||||
Returns whether KV cache should be freed now or will be
|
||||
freed asynchronously and optionally returns KV transfer
|
||||
params.
|
||||
|
||||
Worker-side: runs in each worker, loads/saves KV cache to/from
|
||||
the Connector based on the metadata.
|
||||
start_load_kv() - starts loading all KVs (maybe async)
|
||||
wait_for_layer_load() - blocks until layer i load is done
|
||||
|
||||
save_kv_layer() - starts saving KV for layer i (maybe async)
|
||||
wait_for_save() - blocks until all saves are done
|
||||
|
||||
get_finished() - called with ids of finished requests, returns
|
||||
ids of requests that have completed async sending/recving.
|
||||
"""
|
||||
|
||||
import enum
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.request import Request
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class KVConnectorRole(enum.Enum):
|
||||
# Connector running in the scheduler process
|
||||
SCHEDULER = 0
|
||||
|
||||
# Connector running in the worker process
|
||||
WORKER = 1
|
||||
|
||||
|
||||
class KVConnectorMetadata:
|
||||
"""
|
||||
Abstract Metadata used to communicate between the
|
||||
Scheduler KVConnector and Worker KVConnector.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class KVConnectorBase_V1(ABC):
|
||||
|
||||
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
|
||||
logger.warning(
|
||||
"Initializing KVConnectorBase_V1. This API is experimental and "
|
||||
"subject to change in the future as we iterate the design.")
|
||||
self._connector_metadata = KVConnectorMetadata()
|
||||
self._vllm_config = vllm_config
|
||||
self._role = role
|
||||
|
||||
@property
|
||||
def role(self) -> KVConnectorRole:
|
||||
return self._role
|
||||
|
||||
# ==============================
|
||||
# Worker-side methods
|
||||
# ==============================
|
||||
|
||||
def bind_connector_metadata(
|
||||
self, connector_metadata: KVConnectorMetadata) -> None:
|
||||
"""Set the connector metadata from the scheduler.
|
||||
|
||||
This function should be called by the model runner every time
|
||||
before the model execution. The metadata will be used for runtime
|
||||
KV cache loading and saving.
|
||||
|
||||
Args:
|
||||
connector_metadata (dict): the connector metadata.
|
||||
"""
|
||||
self._connector_metadata = connector_metadata
|
||||
|
||||
def clear_connector_metadata(self) -> None:
|
||||
"""Clear the connector metadata.
|
||||
|
||||
This function should be called by the model runner every time
|
||||
after the model execution.
|
||||
"""
|
||||
self._connector_metadata = KVConnectorMetadata()
|
||||
|
||||
def _get_connector_metadata(self) -> KVConnectorMetadata:
|
||||
"""Get the connector metadata.
|
||||
|
||||
This function should only be called inside the connector.
|
||||
|
||||
Returns:
|
||||
ConnectorMetadata: the connector metadata.
|
||||
"""
|
||||
return self._connector_metadata
|
||||
|
||||
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
|
||||
"""
|
||||
Initialize with the KV caches. Useful for pre-registering the
|
||||
KV Caches in the KVConnector (e.g. for NIXL).
|
||||
|
||||
Args: kv_caches:
|
||||
dictionary of layer names, kv cache
|
||||
"""
|
||||
return
|
||||
|
||||
@abstractmethod
|
||||
def start_load_kv(self, forward_context: "ForwardContext",
|
||||
**kwargs) -> None:
|
||||
"""
|
||||
Start loading the KV cache from the connector to vLLM's paged
|
||||
KV buffer. This is called from the forward context before the
|
||||
forward pass to enable async loading during model execution.
|
||||
|
||||
Args:
|
||||
forward_context (ForwardContext): the forward context.
|
||||
**kwargs: additional arguments for the load operation
|
||||
|
||||
Note:
|
||||
The number of elements in kv_caches and layer_names should be
|
||||
the same.
|
||||
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def wait_for_layer_load(self, layer_name: str) -> None:
|
||||
"""
|
||||
Block until the KV for a specific layer is loaded into vLLM's
|
||||
paged buffer. This is called from within attention layer to ensure
|
||||
async copying from start_load_kv is complete.
|
||||
|
||||
This interface will be useful for layer-by-layer pipelining.
|
||||
|
||||
Args:
|
||||
layer_name: the name of that layer
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
|
||||
attn_metadata: "AttentionMetadata", **kwargs) -> None:
|
||||
"""
|
||||
Start saving a layer of KV cache from vLLM's paged buffer
|
||||
to the connector. This is called from within attention layer to
|
||||
enable async copying during execution.
|
||||
|
||||
Args:
|
||||
layer_name (str): the name of the layer.
|
||||
kv_layer (torch.Tensor): the paged KV buffer of the current
|
||||
layer in vLLM.
|
||||
attn_metadata (AttentionMetadata): the attention metadata.
|
||||
**kwargs: additional arguments for the save operation.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def wait_for_save(self):
|
||||
"""
|
||||
Block until all the save operations is done. This is called
|
||||
as the forward context exits to ensure that the async saving
|
||||
from save_kv_layer is complete before finishing the forward.
|
||||
|
||||
This prevents overwrites of paged KV buffer before saving done.
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_finished(
|
||||
self, finished_req_ids: set[str]
|
||||
) -> tuple[Optional[set[str]], Optional[set[str]]]:
|
||||
"""
|
||||
Notifies worker-side connector ids of requests that have
|
||||
finished generating tokens.
|
||||
|
||||
Returns:
|
||||
ids of requests that have finished asynchronous transfer
|
||||
(requests that previously returned True from request_finished()),
|
||||
tuple of (sending/saving ids, recving/loading ids).
|
||||
The finished saves/sends req ids must belong to a set provided in a
|
||||
call to this method (this call or a prior one).
|
||||
"""
|
||||
return None, None
|
||||
|
||||
# ==============================
|
||||
# Scheduler-side methods
|
||||
# ==============================
|
||||
|
||||
@abstractmethod
|
||||
def get_num_new_matched_tokens(
|
||||
self,
|
||||
request: "Request",
|
||||
num_computed_tokens: int,
|
||||
) -> tuple[int, bool]:
|
||||
"""
|
||||
Get number of new tokens that can be loaded from the
|
||||
external KV cache beyond the num_computed_tokens.
|
||||
|
||||
Args:
|
||||
request (Request): the request object.
|
||||
num_computed_tokens (int): the number of locally
|
||||
computed tokens for this request
|
||||
|
||||
Returns:
|
||||
A tuple with the following elements:
|
||||
- The number of tokens that can be loaded from the
|
||||
external KV cache beyond what is already computed.
|
||||
- `True` if external KV cache tokens will be loaded
|
||||
asynchronously (between scheduler steps). Must be
|
||||
'False' if the first element is 0.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update_state_after_alloc(self, request: "Request",
|
||||
blocks: "KVCacheBlocks",
|
||||
num_external_tokens: int):
|
||||
"""
|
||||
Update KVConnector state after block allocation.
|
||||
|
||||
If get_num_new_matched_tokens previously returned True for a
|
||||
request, this function may be called twice for that same request -
|
||||
first when blocks are allocated for the connector tokens to be
|
||||
asynchronously loaded into, and second when any additional blocks
|
||||
are allocated, after the load/transfer is complete.
|
||||
|
||||
Args:
|
||||
request (Request): the request object.
|
||||
blocks (KVCacheBlocks): the blocks allocated for the request.
|
||||
num_external_tokens (int): the number of tokens that will be
|
||||
loaded from the external KV cache.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def build_connector_meta(
|
||||
self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
|
||||
"""
|
||||
Build the connector metadata for this step.
|
||||
|
||||
This function should NOT modify fields in the scheduler_output.
|
||||
Also, calling this function will reset the state of the connector.
|
||||
|
||||
Args:
|
||||
scheduler_output (SchedulerOutput): the scheduler output object.
|
||||
"""
|
||||
pass
|
||||
|
||||
def request_finished(
|
||||
self,
|
||||
request: "Request",
|
||||
block_ids: list[int],
|
||||
) -> tuple[bool, Optional[dict[str, Any]]]:
|
||||
"""
|
||||
Called when a request has finished, before its blocks are freed.
|
||||
|
||||
Returns:
|
||||
True if the request is being saved/sent asynchronously and blocks
|
||||
should not be freed until the request_id is returned from
|
||||
get_finished().
|
||||
Optional KVTransferParams to be included in the request outputs
|
||||
returned by the engine.
|
||||
"""
|
||||
return False, None
|
||||
@@ -0,0 +1,134 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
from lmcache.integration.vllm.vllm_v1_adapter import LMCacheConnectorV1Impl
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.request import Request
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class LMCacheConnectorV1(KVConnectorBase_V1):
|
||||
|
||||
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
|
||||
super().__init__(vllm_config=vllm_config, role=role)
|
||||
self._lmcache_engine = LMCacheConnectorV1Impl(vllm_config, role, self)
|
||||
|
||||
# ==============================
|
||||
# Worker-side methods
|
||||
# ==============================
|
||||
def start_load_kv(self, forward_context: "ForwardContext",
|
||||
**kwargs) -> None:
|
||||
"""
|
||||
Start loading the KV cache from the connector to vLLM's paged
|
||||
KV buffer. This is called from the forward context before the
|
||||
forward pass to enable async loading during model execution.
|
||||
|
||||
Args:
|
||||
forward_context (ForwardContext): the forward context.
|
||||
**kwargs: additional arguments for the load operation
|
||||
|
||||
Note:
|
||||
The number of elements in kv_caches and layer_names should be
|
||||
the same.
|
||||
|
||||
"""
|
||||
self._lmcache_engine.start_load_kv(forward_context, **kwargs)
|
||||
|
||||
def wait_for_layer_load(self, layer_name: str) -> None:
|
||||
"""
|
||||
Block until the KV for a specific layer is loaded into vLLM's
|
||||
paged buffer. This is called from within attention layer to ensure
|
||||
async copying from start_load_kv is complete.
|
||||
|
||||
This interface will be useful for layer-by-layer pipelining.
|
||||
|
||||
Args:
|
||||
layer_name: the name of that layer
|
||||
"""
|
||||
self._lmcache_engine.wait_for_layer_load(layer_name)
|
||||
|
||||
def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
|
||||
attn_metadata: "AttentionMetadata", **kwargs) -> None:
|
||||
"""
|
||||
Start saving the a layer of KV cache from vLLM's paged buffer
|
||||
to the connector. This is called from within attention layer to
|
||||
enable async copying during execution.
|
||||
|
||||
Args:
|
||||
layer_name (str): the name of the layer.
|
||||
kv_layer (torch.Tensor): the paged KV buffer of the current
|
||||
layer in vLLM.
|
||||
attn_metadata (AttentionMetadata): the attention metadata.
|
||||
**kwargs: additional arguments for the save operation.
|
||||
"""
|
||||
self._lmcache_engine.save_kv_layer(layer_name, kv_layer, attn_metadata,
|
||||
**kwargs)
|
||||
|
||||
def wait_for_save(self):
|
||||
"""
|
||||
Block until all the save operations is done. This is called
|
||||
as the forward context exits to ensure that the async saving
|
||||
from save_kv_layer is complete before finishing the forward.
|
||||
|
||||
This prevents overwrites of paged KV buffer before saving done.
|
||||
"""
|
||||
self._lmcache_engine.wait_for_save()
|
||||
|
||||
# ==============================
|
||||
# Scheduler-side methods
|
||||
# ==============================
|
||||
def get_num_new_matched_tokens(
|
||||
self,
|
||||
request: "Request",
|
||||
num_computed_tokens: int,
|
||||
) -> tuple[int, bool]:
|
||||
"""
|
||||
Get number of new tokens that can be loaded from the
|
||||
external KV cache beyond the num_computed_tokens.
|
||||
|
||||
Args:
|
||||
request (Request): the request object.
|
||||
num_computed_tokens (int): the number of locally
|
||||
computed tokens for this request
|
||||
|
||||
Returns:
|
||||
the number of tokens that can be loaded from the
|
||||
external KV cache beyond what is already computed.
|
||||
"""
|
||||
return self._lmcache_engine.get_num_new_matched_tokens(
|
||||
request, num_computed_tokens), False
|
||||
|
||||
def update_state_after_alloc(self, request: "Request",
|
||||
blocks: "KVCacheBlocks",
|
||||
num_external_tokens: int):
|
||||
"""
|
||||
Update KVConnector state after block allocation.
|
||||
"""
|
||||
self._lmcache_engine.update_state_after_alloc(request,
|
||||
num_external_tokens)
|
||||
|
||||
def build_connector_meta(
|
||||
self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
|
||||
"""
|
||||
Build the connector metadata for this step.
|
||||
|
||||
This function should NOT modify fields in the scheduler_output.
|
||||
Also, calling this function will reset the state of the connector.
|
||||
|
||||
Args:
|
||||
scheduler_output (SchedulerOutput): the scheduler output object.
|
||||
"""
|
||||
return self._lmcache_engine.build_connector_meta(scheduler_output)
|
||||
201
vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
Normal file
201
vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
Normal file
@@ -0,0 +1,201 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import copy
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import KVTransferConfig, VllmConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.factory import (
|
||||
KVConnectorFactory)
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.v1.request import Request
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MultiKVConnectorMetadata(KVConnectorMetadata):
|
||||
metadata: tuple[KVConnectorMetadata, ...]
|
||||
extra_async_saves: Optional[dict[str, int]] = None
|
||||
|
||||
|
||||
class MultiConnector(KVConnectorBase_V1):
|
||||
"""
|
||||
A wrapper for using multiple KVConnectors at the same time.
|
||||
|
||||
The current logic is:
|
||||
- Load KV from the first connector that advertises available tokens from
|
||||
get_num_new_matched_tokens(), based on the order in the config.
|
||||
- Save to all connectors.
|
||||
"""
|
||||
|
||||
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
|
||||
super().__init__(vllm_config=vllm_config, role=role)
|
||||
self._connectors: list[KVConnectorBase_V1] = []
|
||||
ktcs = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
|
||||
"connectors")
|
||||
assert ktcs is not None
|
||||
for ktc in ktcs:
|
||||
temp_config = copy.copy(vllm_config)
|
||||
temp_config.kv_transfer_config = KVTransferConfig(**ktc)
|
||||
self._connectors.append(
|
||||
KVConnectorFactory.create_connector_v1(temp_config, role))
|
||||
|
||||
# A mapping from request id to the index of the connector chosen to
|
||||
# load the request from (if any).
|
||||
self._requests_to_connector: dict[str, int] = {}
|
||||
|
||||
# Keeps track of *additional* remaining async saves (beyond 1) to be
|
||||
# finished per request. Not needed for async loads since we only allow
|
||||
# a single connector to load.
|
||||
# Propagated from scheduler to worker side via the connector metadata.
|
||||
self._extra_async_saves: dict[str, int] = {}
|
||||
|
||||
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
|
||||
for c in self._connectors:
|
||||
c.register_kv_caches(kv_caches)
|
||||
|
||||
# We must override the base class method here because we need to bind
|
||||
# the metadata to each connector in the order of the connectors in the
|
||||
# MultiKVConnectorMetadata.
|
||||
def bind_connector_metadata(
|
||||
self, connector_metadata: KVConnectorMetadata) -> None:
|
||||
assert isinstance(connector_metadata, MultiKVConnectorMetadata)
|
||||
if connector_metadata.extra_async_saves:
|
||||
self._extra_async_saves.update(
|
||||
connector_metadata.extra_async_saves)
|
||||
for c, cm in zip(self._connectors, connector_metadata.metadata):
|
||||
c.bind_connector_metadata(cm)
|
||||
|
||||
def clear_connector_metadata(self) -> None:
|
||||
for c in self._connectors:
|
||||
c.clear_connector_metadata()
|
||||
|
||||
# ==============================
|
||||
# Worker-side methods
|
||||
# ==============================
|
||||
def start_load_kv(self, forward_context: "ForwardContext",
|
||||
**kwargs) -> None:
|
||||
for c in self._connectors:
|
||||
c.start_load_kv(forward_context, **kwargs)
|
||||
|
||||
def wait_for_layer_load(self, layer_name: str) -> None:
|
||||
for c in self._connectors:
|
||||
c.wait_for_layer_load(layer_name)
|
||||
|
||||
def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
|
||||
attn_metadata: "AttentionMetadata", **kwargs) -> None:
|
||||
for c in self._connectors:
|
||||
c.save_kv_layer(layer_name, kv_layer, attn_metadata, **kwargs)
|
||||
|
||||
def wait_for_save(self):
|
||||
for c in self._connectors:
|
||||
c.wait_for_save()
|
||||
|
||||
def get_finished(
|
||||
self, finished_req_ids: set[str]
|
||||
) -> tuple[Optional[set[str]], Optional[set[str]]]:
|
||||
finished_sending: set[str] = set()
|
||||
finished_recving: set[str] = set()
|
||||
for c in self._connectors:
|
||||
sending, recving = c.get_finished(finished_req_ids)
|
||||
if not recving and not sending:
|
||||
continue
|
||||
# Aggregate finished recving request ids.
|
||||
finished_recving.update(recving or ())
|
||||
# Aggregate finished sending request ids - only include
|
||||
# once we've drained the "extra" count (for cases where
|
||||
# more than one connector is async-saving the same request).
|
||||
for req_id in sending or ():
|
||||
extra_pending = self._extra_async_saves.get(req_id)
|
||||
if extra_pending is None:
|
||||
finished_sending.add(req_id)
|
||||
continue
|
||||
assert extra_pending > 0
|
||||
if extra_pending == 1:
|
||||
del self._extra_async_saves[req_id]
|
||||
else:
|
||||
self._extra_async_saves[req_id] = extra_pending - 1
|
||||
|
||||
return finished_sending or None, finished_recving or None
|
||||
|
||||
# ==============================
|
||||
# Scheduler-side methods
|
||||
# ==============================
|
||||
def get_num_new_matched_tokens(
|
||||
self,
|
||||
request: "Request",
|
||||
num_computed_tokens: int,
|
||||
) -> tuple[int, bool]:
|
||||
to_return = (0, False)
|
||||
for i, c in enumerate(self._connectors):
|
||||
toks, load_async = c.get_num_new_matched_tokens(
|
||||
request, num_computed_tokens)
|
||||
# The first connector that has new matched tokens will be assigned
|
||||
# to this request.
|
||||
if to_return[0] == 0 and toks > 0:
|
||||
self._requests_to_connector[request.request_id] = i
|
||||
to_return = (toks, load_async)
|
||||
return to_return
|
||||
|
||||
def update_state_after_alloc(self, request: "Request",
|
||||
blocks: "KVCacheBlocks",
|
||||
num_external_tokens: int):
|
||||
chosen_connector = self._requests_to_connector.get(
|
||||
request.request_id, -1)
|
||||
empty_blocks = blocks.new_empty()
|
||||
for i, c in enumerate(self._connectors):
|
||||
if i == chosen_connector:
|
||||
# Forward call to the chosen connector (if any).
|
||||
c.update_state_after_alloc(request, blocks,
|
||||
num_external_tokens)
|
||||
else:
|
||||
# Call with empty blocks for other connectors.
|
||||
c.update_state_after_alloc(request, empty_blocks, 0)
|
||||
|
||||
def build_connector_meta(
|
||||
self,
|
||||
scheduler_output: SchedulerOutput) -> MultiKVConnectorMetadata:
|
||||
metadata = MultiKVConnectorMetadata(metadata=tuple(
|
||||
c.build_connector_meta(scheduler_output)
|
||||
for c in self._connectors))
|
||||
if self._extra_async_saves:
|
||||
metadata.extra_async_saves = self._extra_async_saves
|
||||
self._extra_async_saves = {}
|
||||
return metadata
|
||||
|
||||
def request_finished(
|
||||
self,
|
||||
request: "Request",
|
||||
blocks: list[int],
|
||||
) -> tuple[bool, Optional[dict[str, Any]]]:
|
||||
async_saves = 0
|
||||
kv_txfer_params = None
|
||||
for c in self._connectors:
|
||||
async_save, txfer_params = c.request_finished(request, blocks)
|
||||
if async_save:
|
||||
async_saves += 1
|
||||
if txfer_params is not None:
|
||||
if kv_txfer_params is not None:
|
||||
#TODO we can probably change this to merge the dicts here,
|
||||
# checking for key clashes.
|
||||
raise RuntimeError(
|
||||
"Only one connector can produce KV transfer params")
|
||||
kv_txfer_params = txfer_params
|
||||
if async_saves > 1:
|
||||
self._extra_async_saves[request.request_id] = async_saves - 1
|
||||
|
||||
# Clean up other state for this request.
|
||||
self._requests_to_connector.pop(request.request_id, None)
|
||||
|
||||
return async_saves > 0, kv_txfer_params
|
||||
1030
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
Normal file
1030
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,384 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import hashlib
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import safetensors
|
||||
import torch
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.attention.backends.mla.common import MLACommonMetadata
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.request import Request
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ReqMeta:
|
||||
# Request tokens
|
||||
token_ids: torch.Tensor
|
||||
# Slot mappings, should have the same length as token_ids
|
||||
slot_mapping: torch.Tensor
|
||||
# Is store or load
|
||||
is_store: bool
|
||||
|
||||
@staticmethod
|
||||
def make_meta(token_ids: list[int], block_ids: list[int], block_size: int,
|
||||
is_store: bool) -> "ReqMeta":
|
||||
valid_num_tokens = align_to_block_size(len(token_ids), block_size)
|
||||
token_ids_tensor = torch.tensor(token_ids)[:valid_num_tokens]
|
||||
block_ids_tensor = torch.tensor(block_ids)
|
||||
num_blocks = block_ids_tensor.shape[0]
|
||||
block_offsets = torch.arange(0, block_size)
|
||||
slot_mapping = block_offsets.reshape((1, block_size)) + \
|
||||
block_ids_tensor.reshape((num_blocks, 1)) * block_size
|
||||
slot_mapping = slot_mapping.flatten()[:valid_num_tokens]
|
||||
return ReqMeta(
|
||||
token_ids=token_ids_tensor,
|
||||
slot_mapping=slot_mapping,
|
||||
is_store=is_store,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SharedStorageConnectorMetadata(KVConnectorMetadata):
|
||||
requests: list[ReqMeta]
|
||||
|
||||
def __init__(self):
|
||||
self.requests = []
|
||||
|
||||
def add_request(
|
||||
self,
|
||||
token_ids: list[int],
|
||||
block_ids: list[int],
|
||||
block_size: int,
|
||||
is_store: bool,
|
||||
) -> None:
|
||||
self.requests.append(
|
||||
ReqMeta.make_meta(token_ids, block_ids, block_size, is_store))
|
||||
|
||||
|
||||
class SharedStorageConnector(KVConnectorBase_V1):
|
||||
# NOTE: This is Simple debug implementation of the KV connector.
|
||||
# It save / load the KV cache to / from the disk.
|
||||
# It does extra work which will overwrite the existing prefix-cache in GPU
|
||||
# - to remove the overhead, need to add some "mask" in the ReqMeta class
|
||||
|
||||
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
|
||||
super().__init__(vllm_config=vllm_config, role=role)
|
||||
self._block_size = vllm_config.cache_config.block_size
|
||||
self._requests_need_load: dict[str, Request] = {}
|
||||
transfer_config = vllm_config.kv_transfer_config
|
||||
self._storage_path = transfer_config.get_from_extra_config(
|
||||
"shared_storage_path", "/tmp")
|
||||
logger.info(vllm_config.kv_transfer_config)
|
||||
logger.info("Shared storage path is %s", self._storage_path)
|
||||
|
||||
def start_load_kv(self, forward_context: "ForwardContext",
|
||||
**kwargs) -> None:
|
||||
"""Start loading the KV cache from the connector buffer to vLLM's
|
||||
paged KV buffer.
|
||||
|
||||
Args:
|
||||
forward_context (ForwardContext): the forward context.
|
||||
**kwargs: additional arguments for the load operation
|
||||
|
||||
Note:
|
||||
The number of elements in kv_caches and layer_names should be
|
||||
the same.
|
||||
"""
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
|
||||
def inject_kv_into_layer(
|
||||
dst_kv_cache_layer: torch.Tensor,
|
||||
src_kv_cache: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
) -> None:
|
||||
"""Inject the KV cache into the layer.
|
||||
|
||||
Args:
|
||||
dst_kv_cache_layer (torch.Tensor): the destination KV cache
|
||||
layer. In shape [2, num_pages, page_size, xxx] if not
|
||||
using MLA, [num_pages, page_size, xxx] otherwise.
|
||||
src_kv_cache (torch.Tensor): the source KV cache. In shape
|
||||
[2, num_tokens, xxx] if not using MLA, [num_tokens, xxx]
|
||||
otherwise.
|
||||
slot_mapping (torch.Tensor): the slot mapping. In shape
|
||||
[num_tokens].
|
||||
"""
|
||||
dst_kv_cache_layer_shape = dst_kv_cache_layer.shape
|
||||
if isinstance(attn_metadata, MLACommonMetadata):
|
||||
num_pages = dst_kv_cache_layer_shape[0]
|
||||
page_size = dst_kv_cache_layer_shape[1]
|
||||
dst_kv_cache_layer = dst_kv_cache_layer.reshape(
|
||||
num_pages * page_size, -1)
|
||||
dst_kv_cache_layer[slot_mapping, ...] = src_kv_cache
|
||||
dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape)
|
||||
else:
|
||||
num_pages = dst_kv_cache_layer_shape[1]
|
||||
page_size = dst_kv_cache_layer_shape[2]
|
||||
dst_kv_cache_layer = dst_kv_cache_layer.reshape(
|
||||
2, num_pages * page_size, -1)
|
||||
dst_kv_cache_layer[:, slot_mapping, ...] = src_kv_cache
|
||||
dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape)
|
||||
|
||||
# Get the metadata
|
||||
metadata: KVConnectorMetadata = self._get_connector_metadata()
|
||||
assert isinstance(metadata, SharedStorageConnectorMetadata)
|
||||
|
||||
if metadata is None:
|
||||
logger.warning(
|
||||
"In connector.start_load_kv, but the connector metadata is None"
|
||||
)
|
||||
return
|
||||
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
if attn_metadata is None:
|
||||
logger.warning(
|
||||
"In connector.start_load_kv, but the attn_metadata is None")
|
||||
return
|
||||
|
||||
# Load the KV for each request each layer
|
||||
for request in metadata.requests:
|
||||
if request.is_store:
|
||||
continue
|
||||
logger.info("Inject KV cache of %d tokens to the paged memory",
|
||||
len(request.slot_mapping))
|
||||
for layer_name in forward_context.no_compile_layers:
|
||||
attn_layer = forward_context.no_compile_layers[layer_name]
|
||||
kv_cache_layer = attn_layer.kv_cache[\
|
||||
forward_context.virtual_engine]
|
||||
|
||||
filename = self._generate_filename_debug(
|
||||
layer_name, request.token_ids)
|
||||
kv_cache = safetensors.torch.load_file(
|
||||
filename)["kv_cache"].cuda()
|
||||
inject_kv_into_layer(kv_cache_layer, kv_cache,
|
||||
request.slot_mapping)
|
||||
|
||||
def wait_for_layer_load(self, layer_name: str) -> None:
|
||||
"""Blocking until the KV for a specific layer is loaded into vLLM's
|
||||
paged buffer.
|
||||
|
||||
This interface will be useful for layer-by-layer pipelining.
|
||||
|
||||
Args:
|
||||
layer_name: the name of that layer
|
||||
"""
|
||||
return
|
||||
|
||||
def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
|
||||
attn_metadata: "AttentionMetadata", **kwargs) -> None:
|
||||
"""Start saving the KV cache of the layer from vLLM's paged buffer
|
||||
to the connector.
|
||||
|
||||
Args:
|
||||
layer_name (str): the name of the layer.
|
||||
kv_layer (torch.Tensor): the paged KV buffer of the current
|
||||
layer in vLLM.
|
||||
attn_metadata (AttentionMetadata): the attention metadata.
|
||||
**kwargs: additional arguments for the save operation.
|
||||
"""
|
||||
|
||||
def extract_kv_from_layer(
|
||||
layer: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""Extract the KV cache from the layer.
|
||||
|
||||
Assume the shape of the layer is (2, num_pages, page_size, xxx)
|
||||
if MLA is not used, and (num_pages, page_size, xxx) otherwise.
|
||||
"""
|
||||
if isinstance(attn_metadata, MLACommonMetadata):
|
||||
num_pages, page_size = layer.shape[0], layer.shape[1]
|
||||
return layer.reshape(num_pages * page_size, -1)[slot_mapping,
|
||||
...]
|
||||
num_pages, page_size = layer.shape[1], layer.shape[2]
|
||||
return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping,
|
||||
...]
|
||||
|
||||
connector_metadata = self._get_connector_metadata()
|
||||
assert isinstance(connector_metadata, SharedStorageConnectorMetadata)
|
||||
for request in connector_metadata.requests:
|
||||
if request.is_store:
|
||||
filename = self._generate_filename_debug(
|
||||
layer_name, request.token_ids)
|
||||
kv_cache = extract_kv_from_layer(kv_layer,
|
||||
request.slot_mapping)
|
||||
tensors = {"kv_cache": kv_cache.detach().cpu()}
|
||||
safetensors.torch.save_file(tensors, filename)
|
||||
|
||||
def wait_for_save(self):
|
||||
return
|
||||
|
||||
def get_num_new_matched_tokens(
|
||||
self,
|
||||
request: "Request",
|
||||
num_computed_tokens: int,
|
||||
) -> tuple[int, bool]:
|
||||
"""
|
||||
Get number of new tokens that can be loaded from the
|
||||
external KV cache beyond the num_computed_tokens.
|
||||
|
||||
Args:
|
||||
request (Request): the request object.
|
||||
num_computed_tokens (int): the number of locally
|
||||
computed tokens for this request
|
||||
|
||||
Returns:
|
||||
the number of tokens that can be loaded from the
|
||||
external KV cache beyond what is already computed.
|
||||
"""
|
||||
# NOTE: in this debug implementation, we assume that the prompt is
|
||||
# cached_prompt + newly_generated_single_token
|
||||
# Therefore, we use prompt_token_ids[:-1] to determine the folder name
|
||||
|
||||
# NOTE: in current v1 scheduler, the num_computed_tokens is aligned
|
||||
# with the block granularity. And it expects the returned blocks and
|
||||
# num_computed_tokens to also be aligned with the block granularity.
|
||||
if not self._found_match_for_request(request):
|
||||
return 0, False
|
||||
|
||||
logger.info("External Cache Hit!")
|
||||
|
||||
# Now, first num_tokens_to_check tokens are hit, we need to prepare
|
||||
# the metadata for the worker connector to correctly load the KV
|
||||
num_tokens_to_check = align_to_block_size(
|
||||
len(request.prompt_token_ids) - 1, self._block_size)
|
||||
|
||||
return num_tokens_to_check - num_computed_tokens, False
|
||||
|
||||
def update_state_after_alloc(self, request: "Request",
|
||||
blocks: "KVCacheBlocks",
|
||||
num_external_tokens: int):
|
||||
"""
|
||||
Update KVConnector state after block allocation.
|
||||
|
||||
If blocks were allocated, add to _requests_need_load,
|
||||
such that we load the KVs in the next forward pass.
|
||||
"""
|
||||
if num_external_tokens > 0:
|
||||
self._requests_need_load[request.request_id] = request
|
||||
|
||||
def build_connector_meta(
|
||||
self,
|
||||
scheduler_output: SchedulerOutput,
|
||||
) -> KVConnectorMetadata:
|
||||
"""Build the connector metadata for this step.
|
||||
|
||||
This function should NOT modify any fields in the scheduler_output.
|
||||
Also, calling this function will reset the state of the connector.
|
||||
|
||||
Args:
|
||||
scheduler_output (SchedulerOutput): the scheduler output object.
|
||||
"""
|
||||
meta = SharedStorageConnectorMetadata()
|
||||
|
||||
total_need_load = 0
|
||||
for new_req in scheduler_output.scheduled_new_reqs:
|
||||
if new_req.req_id in self._requests_need_load:
|
||||
meta.add_request(token_ids=new_req.prompt_token_ids,
|
||||
block_ids=new_req.block_ids[0],
|
||||
block_size=self._block_size,
|
||||
is_store=False)
|
||||
total_need_load += 1
|
||||
else:
|
||||
# NOTE: here, we set the store and load being exclusive,
|
||||
# but a single request can have both store and load.
|
||||
# NOTE(rob): for this debug implementation, we only cache
|
||||
# the original prompt tokens.
|
||||
if not self._found_match_for_request(new_req):
|
||||
meta.add_request(token_ids=new_req.prompt_token_ids,
|
||||
block_ids=new_req.block_ids[0],
|
||||
block_size=self._block_size,
|
||||
is_store=True)
|
||||
|
||||
for cached_req in scheduler_output.scheduled_cached_reqs:
|
||||
# NOTE(rob): here we rely on the resumed requests being
|
||||
# the first N requests in the list scheduled_cache_reqs.
|
||||
if not cached_req.resumed_from_preemption:
|
||||
break
|
||||
if cached_req.req_id in self._requests_need_load:
|
||||
# NOTE(rob): cached_req_data does not have the full
|
||||
# list of token ids (only new tokens). So we look it
|
||||
# up in the actual request object.
|
||||
request = self._requests_need_load[cached_req.req_id]
|
||||
total_tokens = (len(cached_req.new_token_ids) +
|
||||
cached_req.num_computed_tokens)
|
||||
token_ids = request.all_token_ids[:total_tokens]
|
||||
|
||||
# NOTE(rob): For resumed req, new_block_ids is all
|
||||
# of the block_ids for the request.
|
||||
block_ids = cached_req.new_block_ids[0]
|
||||
|
||||
meta.add_request(token_ids=token_ids,
|
||||
block_ids=block_ids,
|
||||
block_size=self._block_size,
|
||||
is_store=False)
|
||||
total_need_load += 1
|
||||
|
||||
assert total_need_load == len(self._requests_need_load)
|
||||
self._requests_need_load.clear()
|
||||
return meta
|
||||
|
||||
# ==============================
|
||||
# Helper functions
|
||||
# ==============================
|
||||
|
||||
def _found_match_for_request(
|
||||
self,
|
||||
request: "Request",
|
||||
) -> bool:
|
||||
"""Check if the cache is hit for the request.
|
||||
"""
|
||||
num_tokens_to_check = align_to_block_size(
|
||||
len(request.prompt_token_ids) - 1, self._block_size)
|
||||
foldername = self._generate_foldername_debug(torch.tensor(
|
||||
request.prompt_token_ids)[:num_tokens_to_check],
|
||||
create_folder=False)
|
||||
return os.path.exists(foldername)
|
||||
|
||||
def _generate_foldername_debug(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
create_folder=False,
|
||||
) -> str:
|
||||
"""Generate a folder name based on the hash of the bytes of the input
|
||||
ids.
|
||||
"""
|
||||
input_ids_bytes = input_ids.numpy().tobytes()
|
||||
input_ids_hash = hashlib.md5(input_ids_bytes,
|
||||
usedforsecurity=False).hexdigest()
|
||||
foldername = os.path.join(self._storage_path, input_ids_hash)
|
||||
if create_folder:
|
||||
os.makedirs(foldername, exist_ok=True)
|
||||
return foldername
|
||||
|
||||
def _generate_filename_debug(
|
||||
self,
|
||||
layer_name: str,
|
||||
input_ids: torch.Tensor,
|
||||
) -> str:
|
||||
"""Generate a file name based on the layer name and the hash
|
||||
of the bytes of the input ids.
|
||||
"""
|
||||
foldername = self._generate_foldername_debug(input_ids,
|
||||
create_folder=True)
|
||||
return os.path.join(foldername, f"{layer_name}.safetensors")
|
||||
|
||||
|
||||
def align_to_block_size(num_tokens: int, block_size) -> int:
|
||||
"""Align the number of tokens to the block size.
|
||||
"""
|
||||
return (num_tokens - 1) // block_size * block_size
|
||||
Reference in New Issue
Block a user