Upgrade to vllm 0.17.0 corex v4.1 overlay
@@ -149,6 +149,12 @@ KVConnectorFactory.register_connector(
    "ExampleConnector",
)

KVConnectorFactory.register_connector(
    "ExampleHiddenStatesConnector",
    "vllm.distributed.kv_transfer.kv_connector.v1.example_hidden_states_connector",
    "ExampleHiddenStatesConnector",
)

KVConnectorFactory.register_connector(
    "P2pNcclConnector",
    "vllm.distributed.kv_transfer.kv_connector.v1.p2p.p2p_nccl_connector",
@@ -413,7 +413,20 @@ class TpKVTopology:
                f"by local tensor parallel size {self.tp_size}."
            )
            # P TP > D TP case, return the ratio as negative
            return -remote_tp_size // self.tp_size
        return remote_tp_size // self.tp_size

    def pp_ratio(
        self,
        remote_pp_size: int,
    ) -> int:
        """
        Calculate the pipeline parallel ratio between local and remote PP.
        """
        assert self.pp_size % remote_pp_size == 0 or remote_pp_size % self.pp_size == 0, (
            f"Local pipeline parallel size {self.pp_size} is not divisible "
            f"by remote pipeline parallel size {remote_pp_size} or vice versa."
        )
        return self.pp_size // remote_pp_size if self.pp_size % remote_pp_size == 0 else remote_pp_size // self.pp_size

    def block_size_ratio(
        self,
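A minimal standalone sketch of the pp_ratio logic above (illustrative only, not part of the commit; sizes are made up): whichever side has the larger pipeline-parallel size, the ratio comes out as the same positive integer.

def pp_ratio(local_pp_size: int, remote_pp_size: int) -> int:
    # Mirrors the method above: the two sizes must divide one another.
    assert local_pp_size % remote_pp_size == 0 or remote_pp_size % local_pp_size == 0
    if local_pp_size % remote_pp_size == 0:
        return local_pp_size // remote_pp_size
    return remote_pp_size // local_pp_size

assert pp_ratio(4, 2) == 2  # local PP larger than remote PP
assert pp_ratio(2, 4) == 2  # remote PP larger, same magnitude either way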
@@ -457,6 +470,7 @@ class TpKVTopology:
    def get_target_remote_ranks(
        self,
        remote_tp_size: int,
        remote_pp_size: int
    ) -> list[int]:
        """
        Get the remote TP rank (on P) that the current local TP rank
@@ -464,19 +478,36 @@ class TpKVTopology:
        read from multiple remote ranks.
        """
        tp_ratio = self.tp_ratio(remote_tp_size)
        if tp_ratio > 0:
            return [self.tp_rank // tp_ratio]
        pp_ratio = self.pp_ratio(remote_pp_size)
        target_pp_rank_list = []
        target_tp_rank_list = []
        if self.pp_size < remote_pp_size:
            for i in range(pp_ratio):
                target_pp_rank_list.append(self.pp_rank * pp_ratio + i)
        else:
            target_pp_rank_list.append(self.pp_rank // pp_ratio)

        # P TP > D TP case, D reads from |tp_ratio| remote workers.
        tp_ratio = -tp_ratio
        return [self.tp_rank * tp_ratio + i for i in range(tp_ratio)]
        if self.tp_size < remote_tp_size:
            for i in range(tp_ratio):
                target_tp_rank_list.append(self.tp_rank * tp_ratio + i)
        else:
            target_tp_rank_list.append(self.tp_rank // tp_ratio)

        target_rank_list = []
        for pp_rank in target_pp_rank_list:
            for tp_rank in target_tp_rank_list:
                target_rank = pp_rank * remote_tp_size + tp_rank
                target_rank_list.append((target_rank, pp_rank, tp_rank))

        return target_rank_list

    def get_target_remote_ranks_from_engine_id(
        self,
        remote_engine_id: EngineId,
    ) -> list[int]:
        remote_tp_size = self.remote_tp_size[remote_engine_id]
        return self.get_target_remote_ranks(remote_tp_size)
        remote_pp_size = self.remote_pp_size[remote_engine_id]
        return self.get_target_remote_ranks(remote_tp_size, remote_pp_size)


def get_current_attn_backend(vllm_config: VllmConfig):
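To make the TP-by-PP fan-out above concrete, here is a simplified standalone sketch (not part of the commit; sizes and ranks are made up, and it always walks both dimensions) of how a decode worker would enumerate its prefill peers:

def target_remote_ranks(tp_rank, pp_rank, tp_size, pp_size,
                        remote_tp_size, remote_pp_size):
    tp_ratio = max(tp_size, remote_tp_size) // min(tp_size, remote_tp_size)
    pp_ratio = max(pp_size, remote_pp_size) // min(pp_size, remote_pp_size)
    if pp_size < remote_pp_size:
        pp_ranks = [pp_rank * pp_ratio + i for i in range(pp_ratio)]
    else:
        pp_ranks = [pp_rank // pp_ratio]
    if tp_size < remote_tp_size:
        tp_ranks = [tp_rank * tp_ratio + i for i in range(tp_ratio)]
    else:
        tp_ranks = [tp_rank // tp_ratio]
    return [(p * remote_tp_size + t, p, t) for p in pp_ranks for t in tp_ranks]

# Decode worker (tp_rank=1, pp_rank=0) with TP=2/PP=1 reading from a prefill
# deployment with TP=4/PP=2 touches four remote workers.
print(target_remote_ranks(1, 0, 2, 1, 4, 2))
# -> [(2, 0, 2), (3, 0, 3), (6, 1, 2), (7, 1, 3)]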
@@ -543,6 +543,28 @@ class KVConnectorBase_V1(ABC):
        )
        return None

    @classmethod
    def requires_piecewise_for_cudagraph(cls, extra_config: dict[str, Any]) -> bool:
        """
        Check if this connector requires PIECEWISE CUDA graph mode.

        Connectors that use asynchronous layer-by-layer operations
        (wait_for_layer_load/save_kv_layer) should override this method
        to return True when those operations are enabled. These operations
        cannot be captured in CUDA graphs and will be skipped during replay,
        causing data races. PIECEWISE mode allows Python code to execute
        between graph pieces, ensuring proper synchronization.

        Args:
            extra_config: The kv_connector_extra_config dict from
                KVTransferConfig.

        Returns:
            True if this connector requires PIECEWISE CUDA graph mode,
            False otherwise.
        """
        return False

    def get_finished_count(self) -> int | None:
        """
        Get the count of requests expected to complete send/receive operations
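As a sketch of how a subclass would opt in (the connector class and the flag name are hypothetical, not part of this commit):

from typing import Any

from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorBase_V1


class MyLayerwiseConnector(KVConnectorBase_V1):
    @classmethod
    def requires_piecewise_for_cudagraph(cls, extra_config: dict[str, Any]) -> bool:
        # Only force PIECEWISE when the async per-layer path is actually enabled.
        return bool(extra_config.get("enable_layerwise_transfer", False))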
@@ -17,6 +17,7 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.attention.mla_attention import MLACommonMetadata
from vllm.utils.hashing import safe_hash
from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata
from vllm.v1.core.sched.output import SchedulerOutput

if TYPE_CHECKING:
@@ -118,12 +119,12 @@ class ExampleConnector(KVConnectorBase_V1):
        The number of elements in kv_caches and layer_names should be
        the same.
        """
        attn_metadata = forward_context.attn_metadata

        def inject_kv_into_layer(
            dst_kv_cache_layer: torch.Tensor,
            src_kv_cache: torch.Tensor,
            slot_mapping: torch.Tensor,
            attn_metadata: AttentionMetadata,
        ) -> None:
            """Inject the KV cache into the layer.
@@ -145,6 +146,10 @@ class ExampleConnector(KVConnectorBase_V1):
                    num_pages * page_size, -1
                )
                dst_kv_cache_layer[slot_mapping, ...] = src_kv_cache
            elif isinstance(attn_metadata, TritonAttentionMetadata):
                block_idxs = slot_mapping // self._block_size
                offsets = slot_mapping % self._block_size
                dst_kv_cache_layer[block_idxs, :, offsets] = src_kv_cache
            else:
                num_pages = dst_kv_cache_layer_shape[1]
                page_size = dst_kv_cache_layer_shape[2]
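The Triton branch above addresses the cache by splitting each flat slot index into a block index and an in-block offset. A toy illustration (numbers made up, not from the commit):

import torch

block_size = 16
slot_mapping = torch.tensor([0, 5, 16, 35])
block_idxs = slot_mapping // block_size  # tensor([0, 0, 1, 2])
offsets = slot_mapping % block_size      # tensor([0, 5, 0, 3])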
@@ -186,7 +191,13 @@ class ExampleConnector(KVConnectorBase_V1):
                layer_name, request.token_ids, request.mm_hashes
            )
            kv_cache = safetensors.torch.load_file(filename)["kv_cache"].cuda()
            inject_kv_into_layer(kv_cache_layer, kv_cache, request.slot_mapping)
            if isinstance(attn_metadata, dict):
                inject_kv_into_layer(
                    kv_cache_layer,
                    kv_cache,
                    request.slot_mapping,
                    attn_metadata[layer_name],
                )

    def wait_for_layer_load(self, layer_name: str) -> None:
        """Blocking until the KV for a specific layer is loaded into vLLM's
@@ -229,6 +240,10 @@ class ExampleConnector(KVConnectorBase_V1):
            if isinstance(attn_metadata, MLACommonMetadata):
                num_pages, page_size = layer.shape[0], layer.shape[1]
                return layer.reshape(num_pages * page_size, -1)[slot_mapping, ...]
            elif isinstance(attn_metadata, TritonAttentionMetadata):
                block_idxs = slot_mapping // self._block_size
                offsets = slot_mapping % self._block_size
                return layer[block_idxs, :, offsets]
            num_pages, page_size = layer.shape[1], layer.shape[2]
            return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping, ...]
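The MLA branch instead flattens (num_pages, page_size) into a single slot dimension so one gather pulls out whole tokens. A small sketch with toy shapes (illustrative, not from the commit):

import torch

num_pages, page_size, head_dim = 4, 16, 8
layer = torch.randn(num_pages, page_size, head_dim)
slot_mapping = torch.tensor([3, 17, 40])
flat = layer.reshape(num_pages * page_size, -1)  # (64, 8)
gathered = flat[slot_mapping, ...]               # (3, 8)
assert torch.equal(gathered[1], layer[1, 1])     # slot 17 = page 1, offset 1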
@@ -0,0 +1,354 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Optional

import safetensors
import torch

from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
    KVConnectorBase_V1,
    KVConnectorMetadata,
    KVConnectorRole,
)
from vllm.logger import init_logger
from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput

if TYPE_CHECKING:
    from vllm.v1.core.kv_cache_manager import KVCacheBlocks
    from vllm.v1.kv_cache_interface import KVCacheConfig
    from vllm.v1.request import Request

logger = init_logger(__name__)


def extract_from_kv_cache(
    kv_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    num_tokens: int,
) -> torch.Tensor:
    """Extract data from KV cache
    Assume the shape of the kv_cache is (num_pages, page_size, num_heads, head_size)
    """

    padded_kv = kv_cache.flatten(0, 1)[slot_mapping]
    # shape: [len(slot_mapping), num_heads, head_size]
    return padded_kv[:num_tokens]  # shape: [num_tokens, num_heads, head_size]


@dataclass
class ReqMeta:
    # Request ID
    req_id: str
    # Request filename
    filename: str
    # Request tokens
    token_ids: torch.Tensor
    # Slot mappings, should have the same length as token_ids
    slot_mapping: torch.Tensor
    # Whether this request is a new request or partially computed already
    new_req: bool

    @staticmethod
    def make_meta(
        req_id: str,
        filename: str,
        token_ids: list[int],
        block_ids: list[int],
        block_size: int,
        new_req: bool,
    ) -> "ReqMeta":
        token_ids_tensor = torch.tensor(token_ids)
        block_ids_tensor = torch.tensor(block_ids)
        num_blocks = block_ids_tensor.shape[0]
        block_offsets = torch.arange(0, block_size)
        slot_mapping = (
            block_offsets.reshape((1, block_size))
            + block_ids_tensor.reshape((num_blocks, 1)) * block_size
        )
        slot_mapping = slot_mapping.flatten()
        return ReqMeta(
            req_id=req_id,
            filename=filename,
            token_ids=token_ids_tensor,
            slot_mapping=slot_mapping,
            new_req=new_req,
        )


@dataclass
class ExampleHiddenStatesConnectorMetadata(KVConnectorMetadata):
    requests: list[ReqMeta] = field(default_factory=list)

    def add_request(
        self,
        req_id: str,
        filename: str,
        token_ids: list[int],
        block_ids: list[int],
        block_size: int,
        new_req: bool = True,
    ) -> None:
        self.requests.append(
            ReqMeta.make_meta(
                req_id, filename, token_ids, block_ids, block_size, new_req
            )
        )


class ExampleHiddenStatesConnector(KVConnectorBase_V1):
    """
    Simple debug implementation of a HiddenStatesConnector.

    Simply extracts the hidden states from the kv cache and stores them to disk.
    Must be used in conjunction with the `extract_hidden_states` spec decoding method.
    """

    @property
    def prefer_cross_layer_blocks(self) -> bool:
        """
        Indicates whether this connector prefers KV blocks that hold KV data for all
        layers, which can speed up KV data transfers. Defaults to False.
        """
        # Must be False so that drafter kv cache isn't merged with verifier's
        return False

    def __init__(
        self,
        vllm_config: "VllmConfig",
        role: KVConnectorRole,
        kv_cache_config: Optional["KVCacheConfig"] = None,
    ):
        super().__init__(
            vllm_config=vllm_config,
            role=role,
            kv_cache_config=kv_cache_config,
        )
        self._block_size = vllm_config.cache_config.block_size
        self._storage_path = self._kv_transfer_config.get_from_extra_config(
            "shared_storage_path", "/tmp"
        )
        self.cache_layers: list[str] = []  # set by self.register_kv_caches
        logger.info(self._kv_transfer_config)
        logger.info("Shared storage path is %s", self._storage_path)

        assert self._vllm_config.speculative_config is not None, (
            "ExampleHiddenStatesConnector only works when using "
            "'extract_hidden_states' speculative method"
        )
        spec_config = self._vllm_config.speculative_config.draft_model_config.hf_config
        self.num_hidden_states = len(
            getattr(spec_config, "eagle_aux_hidden_state_layer_ids", [])
        )

        self._request_filenames: dict[str, str] = {}
        self._active_requests: dict[str, NewRequestData] = {}
        self._req_blocks: dict[str, list[int]] = {}

    # ==============================
    # Worker-side methods
    # ==============================
    def start_load_kv(self, *args, **kwargs: Any) -> None:
        pass  # Empty implementation of abstract method

    def wait_for_layer_load(self, layer_name: str) -> None:
        pass  # Empty implementation of abstract method

    def wait_for_save(self):
        pass  # Empty implementation of abstract method

    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
        from vllm.model_executor.models.extract_hidden_states import (
            CacheOnlyAttentionLayer,
        )

        # Filter layers to only include CacheOnlyAttentionLayers
        layers = get_layers_from_vllm_config(
            self._vllm_config, CacheOnlyAttentionLayer, list(kv_caches.keys())
        )
        self.cache_layers = list(layers.keys())
        assert len(self.cache_layers) == 1, (
            f"Expected 1 CacheOnlyAttentionLayer, got {len(self.cache_layers)}"
        )

    def save_kv_layer(
        self,
        layer_name: str,
        kv_layer: torch.Tensor,
        attn_metadata: AttentionMetadata,
        **kwargs: Any,
    ) -> None:
        """Start saving the KV cache of the layer from vLLM's paged buffer
        to the connector.

        Args:
            layer_name (str): the name of the layer.
            kv_layer (torch.Tensor): the paged KV buffer of the current
                layer in vLLM.
            attn_metadata (AttentionMetadata): the attention metadata.
            **kwargs: additional arguments for the save operation.
        """
        if layer_name not in self.cache_layers:
            return

        from vllm.model_executor.models.extract_hidden_states import (
            CacheOnlyAttentionMetadata,
        )

        assert isinstance(attn_metadata, CacheOnlyAttentionMetadata), (
            "ExampleHiddenStatesConnector only supports CacheOnlyAttentionBackend"
        )

        connector_metadata = self._get_connector_metadata()
        assert isinstance(connector_metadata, ExampleHiddenStatesConnectorMetadata)

        os.makedirs(self._storage_path, exist_ok=True)
        for request in connector_metadata.requests:
            hidden_states = extract_from_kv_cache(
                kv_layer, request.slot_mapping, request.token_ids.shape[0]
            )
            tensors = {
                "hidden_states": hidden_states.detach().cpu(),
                "token_ids": request.token_ids.detach().cpu(),
            }
            safetensors.torch.save_file(tensors, request.filename)

    # ==============================
    # Scheduler-side methods
    # ==============================

    def get_num_new_matched_tokens(
        self,
        request: "Request",
        num_computed_tokens: int,
    ) -> tuple[int | None, bool]:
        """
        Get number of new tokens that can be loaded from the
        external KV cache beyond the num_computed_tokens.

        Args:
            request (Request): the request object.
            num_computed_tokens (int): the number of locally
                computed tokens for this request

        Returns:
            the number of tokens that can be loaded from the
            external KV cache beyond what is already computed.
        """
        # This connector is store-only, so we don't need to load any tokens
        return 0, False

    def update_state_after_alloc(
        self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
    ):
        # Usually used to handle allocation of new blocks for requests that are loading
        # tokens from connector's external kv cache. We never load from external cache
        # so this is a no-op.
        assert num_external_tokens == 0, "This connector is store-only"

    def build_connector_meta(
        self,
        scheduler_output: SchedulerOutput,
    ) -> KVConnectorMetadata:
        """Build the connector metadata for this step.

        This function should NOT modify any fields in the scheduler_output.
        Also, calling this function will reset the state of the connector.

        Args:
            scheduler_output (SchedulerOutput): the scheduler output object.
        """
        meta = ExampleHiddenStatesConnectorMetadata()
        for new_req in scheduler_output.scheduled_new_reqs:
            token_ids = new_req.prompt_token_ids or []
            filename = os.path.join(self._storage_path, f"{new_req.req_id}.safetensors")
            meta.add_request(
                new_req.req_id,
                filename=filename,
                token_ids=token_ids,
                block_ids=new_req.block_ids[0],
                block_size=self._block_size,
            )
            self._request_filenames[new_req.req_id] = filename
            self._active_requests[new_req.req_id] = new_req
            self._req_blocks[new_req.req_id] = list(new_req.block_ids[0])

        cached_reqs = scheduler_output.scheduled_cached_reqs
        for i, req_id in enumerate(cached_reqs.req_ids):
            if req_id not in self._active_requests:
                continue

            new_block_ids = cached_reqs.new_block_ids[i]

            cached_req = self._active_requests[req_id]
            req_block_ids = self._req_blocks[req_id]

            assert new_block_ids is not None
            block_ids = new_block_ids[0]

            req_block_ids.extend(block_ids)
            filename = os.path.join(self._storage_path, f"{req_id}.safetensors")

            meta.add_request(
                req_id=req_id,
                filename=filename,
                token_ids=cached_req.prompt_token_ids or [],
                block_ids=req_block_ids,
                block_size=self._block_size,
                new_req=False,
            )

        return meta

    def request_finished(
        self,
        request: "Request",
        block_ids: list[int],
    ) -> tuple[bool, dict[str, Any] | None]:
        """
        Called exactly once when a request has finished, before its blocks are
        freed.

        The connector may assume responsibility for freeing the blocks
        asynchronously by returning True.

        Returns:
            True if the request is being saved/sent asynchronously and blocks
            should not be freed until the request_id is returned from
            get_finished().
            Optional KVTransferParams to be included in the request outputs
            returned by the engine.
        """
        req_id = request.request_id
        req_filename = self._request_filenames.pop(req_id, None)
        _ = self._active_requests.pop(req_id, None)
        _ = self._req_blocks.pop(req_id, None)

        return False, {"hidden_states_path": req_filename}

    @classmethod
    def get_required_kvcache_layout(cls, vllm_config: "VllmConfig") -> str | None:
        """
        Get the required KV cache layout for this connector.
        Args:
            vllm_config (VllmConfig): the vllm config.

        Returns:
            str: the required KV cache layout. e.g. HND, or NHD.
            None if the connector does not require a specific layout.
        """

        if cls is KVConnectorBase_V1:
            raise TypeError(
                "get_required_kvcache_layout should not be called "
                "on the abstract base class"
            )
        # NHD means we have (num_tokens, num_heads)
        # HND means we have (num_heads, num_tokens)
        # For now, we only support NHD layout since this keeps the
        # hidden states for each token together in memory.
        # HND is primarily used when sharding heads across devices.
        return "NHD"
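What ends up on disk is one safetensors file per request. A consumer-side sketch of reading it back (the file name and path are made up, not part of the commit):

import safetensors.torch

saved = safetensors.torch.load_file("/tmp/req-0.safetensors")
hidden_states = saved["hidden_states"]  # [num_tokens, num_heads, head_size]
token_ids = saved["token_ids"]          # [num_tokens]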
@@ -70,6 +70,16 @@ class LMCacheKVEvents(KVConnectorKVEvents):


class LMCacheConnectorV1(KVConnectorBase_V1):
    @classmethod
    def requires_piecewise_for_cudagraph(cls, extra_config: dict[str, Any]) -> bool:
        """
        LMCache requires PIECEWISE CUDA graph mode when layerwise
        operations are enabled. The wait_for_layer_load and save_kv_layer
        methods perform actual async synchronization that cannot be
        captured in CUDA graphs.
        """
        return extra_config.get("use_layerwise", False)

    def __init__(
        self,
        vllm_config: "VllmConfig",
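A hedged configuration sketch (field values are examples only; check KVTransferConfig for the exact schema): the layerwise flag in kv_connector_extra_config is what flips this check to True.

from vllm.config import KVTransferConfig

ktc = KVTransferConfig(
    kv_connector="LMCacheConnectorV1",
    kv_role="kv_both",
    kv_connector_extra_config={"use_layerwise": True},
)
# LMCacheConnectorV1.requires_piecewise_for_cudagraph(ktc.kv_connector_extra_config)
# would then return True.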
@@ -173,6 +173,29 @@ class MooncakeConnector(KVConnectorBase_V1):
            self.connector_scheduler = None
            self.connector_worker = MooncakeConnectorWorker(vllm_config, self.engine_id)


    ############################################################
    # Class Methods
    ############################################################
    @classmethod
    def get_required_kvcache_layout(cls, vllm_config: VllmConfig):
        if vllm_config.model_config is None:
            logger.warning_once(
                "Unable to detect current VLLM config. "
                "Fallback to default kv cache layout."
            )
            return None
        use_mla = vllm_config.model_config.use_mla
        if use_mla:
            # return None when we have mla
            # as the layout should not matter in that case,
            # which fallback to the default behavior.
            return None
        logger.info_once(
            "MooncakeConnector setting KV cache layout to HND for better xfer performance."
        )
        return "HND"

    ############################################################
    # Scheduler Side Methods
    ############################################################
@@ -941,7 +964,13 @@ class MooncakeConnectorWorker:
            assert tensor_size_bytes == curr_tensor_size_bytes, (
                "All kv cache tensors must have the same size"
            )
            kernel_block_size = cache.shape[-2 if self.use_mla else -3]

            cache_layout = get_kv_cache_layout()
            if cache_layout == "HND":
                kernel_block_size = cache.shape[-2]
            else:
                kernel_block_size = cache.shape[-3]

            assert self.block_size == kernel_block_size
            kv_data_ptrs.append(base_addr)
            kv_data_lens.append(tensor_size_bytes)
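The layout decides where the kernel block size sits in the cache shape, which is what the lookup above accounts for. A toy sketch (shapes follow the usual NHD/HND convention and are illustrative, not taken from the commit):

import torch

num_blocks, block_size, num_kv_heads, head_size = 8, 16, 4, 64
nhd_cache = torch.empty(2, num_blocks, block_size, num_kv_heads, head_size)
hnd_cache = torch.empty(2, num_blocks, num_kv_heads, block_size, head_size)
assert nhd_cache.shape[-3] == block_size  # NHD: block size third from the end
assert hnd_cache.shape[-2] == block_size  # HND: block size second from the end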
@@ -112,6 +112,21 @@ class MultiConnector(KVConnectorBase_V1):
    - Save to all connectors.
    """

    @classmethod
    def requires_piecewise_for_cudagraph(cls, extra_config: dict[str, Any]) -> bool:
        """
        MultiConnector requires PIECEWISE CUDA graph mode if any of its
        child connectors require it.
        """
        connectors_config = extra_config.get("connectors", [])
        for conn_config in connectors_config:
            temp_ktc = KVTransferConfig(**conn_config)
            connector_cls = KVConnectorFactory.get_connector_class(temp_ktc)
            child_extra_config = conn_config.get("kv_connector_extra_config", {})
            if connector_cls.requires_piecewise_for_cudagraph(child_extra_config):
                return True
        return False

    def __init__(
        self,
        vllm_config: "VllmConfig",
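A sketch of the child-connector config this loop walks over (connector names and nesting are illustrative, not from the commit): one layerwise child is enough to force PIECEWISE capture for the whole stack.

extra_config = {
    "connectors": [
        {"kv_connector": "LMCacheConnectorV1", "kv_role": "kv_both",
         "kv_connector_extra_config": {"use_layerwise": True}},
        {"kv_connector": "ExampleConnector", "kv_role": "kv_both"},
    ]
}
# MultiConnector.requires_piecewise_for_cudagraph(extra_config) returns True,
# because the LMCache child reports True for its own extra config.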
File diff suppressed because it is too large