Iluvatar-mrv100 SDK 4.3.0
This commit is contained in:
123
vllm/distributed/kv_transfer/kv_connector/base.py
Normal file
123
vllm/distributed/kv_transfer/kv_connector/base.py
Normal file
@@ -0,0 +1,123 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""
|
||||
KVConnectorBase Class for Distributed KV Cache & Hidden State communication
|
||||
|
||||
The class provides two primary abstract methods:
|
||||
1. send_kv_caches_and_hidden_states(): Send KV caches and hidden states
|
||||
2. recv_kv_caches_and_hidden_states(): Recv KV caches and hidden states
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, List, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
|
||||
|
||||
|
||||
class KVConnectorBase(ABC):
|
||||
"""
|
||||
Abstract base class for a KV connector.
|
||||
|
||||
The class provides two primary abstract methods:
|
||||
1. send_kv_caches_and_hidden_states(): Send KV caches and hidden states
|
||||
2. recv_kv_caches_and_hidden_states(): Recv KV caches and hidden states
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(
|
||||
self,
|
||||
rank: int,
|
||||
local_rank: int,
|
||||
config: "VllmConfig",
|
||||
):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def close(self) -> None:
|
||||
"""Close the buffer and release resources.
|
||||
|
||||
This method is responsible for cleaning up resources related to the
|
||||
connector when it is no longer needed.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: This method must be implemented in subclasses.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def send_kv_caches_and_hidden_states(
|
||||
self,
|
||||
model_executable: torch.nn.Module,
|
||||
model_input: "ModelInputForGPUWithSamplingMetadata",
|
||||
kv_caches: List[torch.Tensor],
|
||||
hidden_or_intermediate_states: Union[torch.Tensor,
|
||||
IntermediateTensors],
|
||||
) -> None:
|
||||
"""
|
||||
Send KV caches and hidden states to the connector.
|
||||
|
||||
This method processes the input tokens, KV caches, and
|
||||
hidden/intermediate states for a given model and sends the data to the
|
||||
decode instance.
|
||||
|
||||
Args:
|
||||
model_executable (torch.nn.Module): The model executable containing
|
||||
start and end layer information.
|
||||
model_input (ModelInputForGPUWithSamplingMetadata): The input
|
||||
metadata from vLLM.
|
||||
kv_caches (List[torch.Tensor]): List of KV caches (keys and values)
|
||||
for each layer.
|
||||
hidden_or_intermediate_states (Union[torch.Tensor,
|
||||
IntermediateTensors]):
|
||||
The hidden or intermediate states associated with the tokens.
|
||||
|
||||
Returns:
|
||||
None
|
||||
|
||||
"""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def recv_kv_caches_and_hidden_states(
|
||||
self, model_executable: torch.nn.Module,
|
||||
model_input: "ModelInputForGPUWithSamplingMetadata",
|
||||
kv_caches: List[torch.Tensor]
|
||||
) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool,
|
||||
"ModelInputForGPUWithSamplingMetadata"]:
|
||||
"""
|
||||
Receive KV caches and hidden states from the connector.
|
||||
|
||||
This method attempts to retrieve KV caches and hidden states for input
|
||||
tokens. If all required KV caches and hidden states are received, it
|
||||
will bypass model input, else it will fall back to normal vLLM model
|
||||
forwarding.
|
||||
|
||||
Args:
|
||||
model_executable (torch.nn.Module):
|
||||
The model executable from vLLM modelrunner.
|
||||
model_input (ModelInputForGPUWithSamplingMetadata):
|
||||
The model input from vLLM modelrunner.
|
||||
kv_caches (List[torch.Tensor]):
|
||||
List of KV caches for each layer.
|
||||
|
||||
Returns:
|
||||
- hidden_or_intermediate_states (torch.Tensor or
|
||||
IntermediateTensors):
|
||||
Concatenated hidden states if all required data is retrieved,
|
||||
otherwise `None`.
|
||||
- bypass_model_exec (bool):
|
||||
Indicates whether the model execution can be skipped (True) or
|
||||
needs to be redone (False).
|
||||
- model_input (ModelInputForGPUWithSamplingMetadata):
|
||||
Optionally adjusted input metadata for re-execution when
|
||||
`bypass_model_exec=False`.
|
||||
|
||||
"""
|
||||
|
||||
raise NotImplementedError
|
||||
64
vllm/distributed/kv_transfer/kv_connector/factory.py
Normal file
64
vllm/distributed/kv_transfer/kv_connector/factory.py
Normal file
@@ -0,0 +1,64 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import importlib
|
||||
from typing import TYPE_CHECKING, Callable, Dict, Type
|
||||
|
||||
from .base import KVConnectorBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
|
||||
|
||||
class KVConnectorFactory:
|
||||
_registry: Dict[str, Callable[[], Type[KVConnectorBase]]] = {}
|
||||
|
||||
@classmethod
|
||||
def register_connector(cls, name: str, module_path: str,
|
||||
class_name: str) -> None:
|
||||
"""Register a connector with a lazy-loading module and class name."""
|
||||
if name in cls._registry:
|
||||
raise ValueError(f"Connector '{name}' is already registered.")
|
||||
|
||||
def loader() -> Type[KVConnectorBase]:
|
||||
module = importlib.import_module(module_path)
|
||||
return getattr(module, class_name)
|
||||
|
||||
cls._registry[name] = loader
|
||||
|
||||
@classmethod
|
||||
def create_connector(cls, rank: int, local_rank: int,
|
||||
config: "VllmConfig") -> KVConnectorBase:
|
||||
connector_name = config.kv_transfer_config.kv_connector
|
||||
if connector_name not in cls._registry:
|
||||
raise ValueError(f"Unsupported connector type: {connector_name}")
|
||||
|
||||
connector_cls = cls._registry[connector_name]()
|
||||
return connector_cls(rank, local_rank, config)
|
||||
|
||||
|
||||
# Register various connectors here.
|
||||
# The registration should not be done in each individual file, as we want to
|
||||
# only load the files corresponding to the current connector.
|
||||
KVConnectorFactory.register_connector(
|
||||
"P2pConnector", "vllm.distributed.kv_transfer.kv_connector.p2p_connector",
|
||||
"P2pConnector")
|
||||
|
||||
KVConnectorFactory.register_connector(
|
||||
"PyNcclConnector",
|
||||
"vllm.distributed.kv_transfer.kv_connector.simple_connector",
|
||||
"SimpleConnector")
|
||||
|
||||
KVConnectorFactory.register_connector(
|
||||
"MooncakeConnector",
|
||||
"vllm.distributed.kv_transfer.kv_connector.simple_connector",
|
||||
"SimpleConnector")
|
||||
|
||||
KVConnectorFactory.register_connector(
|
||||
"LMCacheConnector",
|
||||
"vllm.distributed.kv_transfer.kv_connector.lmcache_connector",
|
||||
"LMCacheConnector")
|
||||
|
||||
KVConnectorFactory.register_connector(
|
||||
"MooncakeStoreConnector",
|
||||
"vllm.distributed.kv_transfer.kv_connector.mooncake_store_connector",
|
||||
"MooncakeStoreConnector")
|
||||
@@ -0,0 +1,98 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""
|
||||
LMCache KV Cache Connector for Distributed Machine Learning Inference
|
||||
|
||||
The LMCacheConnector can (1) transfer KV caches between prefill vLLM worker
|
||||
(KV cache producer) and decode vLLM worker (KV cache consumer) using LMCache;
|
||||
(2) offload and share KV caches.
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, List, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
|
||||
from vllm.logger import init_logger
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class LMCacheConnector(KVConnectorBase):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
rank: int,
|
||||
local_rank: int,
|
||||
config: VllmConfig,
|
||||
):
|
||||
|
||||
self.transfer_config = config.kv_transfer_config
|
||||
self.vllm_config = config
|
||||
|
||||
from lmcache.experimental.cache_engine import LMCacheEngineBuilder
|
||||
from lmcache.integration.vllm.utils import ENGINE_NAME
|
||||
from lmcache.integration.vllm.vllm_adapter import (
|
||||
RetrieveStatus, StoreStatus, init_lmcache_engine,
|
||||
lmcache_retrieve_kv, lmcache_should_retrieve, lmcache_should_store,
|
||||
lmcache_store_kv)
|
||||
logger.info("Initializing LMCacheConfig under kv_transfer_config %s",
|
||||
self.transfer_config)
|
||||
|
||||
# TODO (Jiayi): Find model_config, parallel_config, and cache_config
|
||||
self.engine = init_lmcache_engine(config.model_config,
|
||||
config.parallel_config,
|
||||
config.cache_config)
|
||||
self.lmcache_engine_name = ENGINE_NAME
|
||||
self.lmcache_engine_builder = LMCacheEngineBuilder
|
||||
|
||||
self.model_config = config.model_config
|
||||
self.parallel_config = config.parallel_config
|
||||
self.cache_config = config.cache_config
|
||||
self.lmcache_retrieve_kv = lmcache_retrieve_kv
|
||||
self.lmcache_store_kv = lmcache_store_kv
|
||||
self.lmcache_should_retrieve = lmcache_should_retrieve
|
||||
self.lmcache_should_store = lmcache_should_store
|
||||
self.store_status = StoreStatus
|
||||
self.retrieve_status = RetrieveStatus
|
||||
|
||||
def recv_kv_caches_and_hidden_states(
|
||||
self, model_executable: torch.nn.Module,
|
||||
model_input: "ModelInputForGPUWithSamplingMetadata",
|
||||
kv_caches: List[torch.Tensor]
|
||||
) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool,
|
||||
"ModelInputForGPUWithSamplingMetadata"]:
|
||||
|
||||
retrieve_status = self.lmcache_should_retrieve(model_input)
|
||||
model_input, bypass_model_exec, hidden_or_intermediate_states =\
|
||||
self.lmcache_retrieve_kv(
|
||||
model_executable, model_input, self.cache_config, kv_caches,
|
||||
retrieve_status)
|
||||
return hidden_or_intermediate_states, bypass_model_exec, model_input
|
||||
|
||||
def send_kv_caches_and_hidden_states(
|
||||
self,
|
||||
model_executable: torch.nn.Module,
|
||||
model_input: "ModelInputForGPUWithSamplingMetadata",
|
||||
kv_caches: List[torch.Tensor],
|
||||
hidden_or_intermediate_states: Union[torch.Tensor,
|
||||
IntermediateTensors],
|
||||
) -> None:
|
||||
|
||||
store_status = self.lmcache_should_store(model_input)
|
||||
self.lmcache_store_kv(
|
||||
self.model_config,
|
||||
self.parallel_config,
|
||||
self.cache_config,
|
||||
model_executable,
|
||||
model_input,
|
||||
kv_caches,
|
||||
store_status,
|
||||
)
|
||||
|
||||
def close(self):
|
||||
self.lmcache_engine_builder.destroy(self.lmcache_engine_name)
|
||||
@@ -0,0 +1,216 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""
|
||||
MooncakeStore Connector for Distributed Machine Learning Inference
|
||||
|
||||
The MooncakeStoreConnector transfers KV caches between prefill vLLM workers
|
||||
(KV cache producer) and decode vLLM workers (KV cache consumer) using a
|
||||
database-style KVStore.
|
||||
"""
|
||||
import hashlib
|
||||
from typing import TYPE_CHECKING, List, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
|
||||
from vllm.logger import init_logger
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class MooncakeStoreConnector(KVConnectorBase):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
rank: int,
|
||||
local_rank: int,
|
||||
config: VllmConfig,
|
||||
):
|
||||
self.config = config.kv_transfer_config
|
||||
self.tp_size = config.parallel_config.tensor_parallel_size
|
||||
|
||||
self.local_tp_rank = local_rank
|
||||
|
||||
# Init kv_store
|
||||
if self.config.kv_connector == "MooncakeStoreConnector":
|
||||
# Check if MOONCAKE_CONFIG_PATH is set
|
||||
import os
|
||||
use_mooncake_store = os.getenv('MOONCAKE_CONFIG_PATH') is not None
|
||||
|
||||
if not use_mooncake_store:
|
||||
raise ValueError(
|
||||
"To use MooncakeStoreConnector, you need to pass the ENV: "
|
||||
"'MOONCAKE_CONFIG_PATH=/path/to/mooncake_config.json'.")
|
||||
else:
|
||||
from vllm.distributed.kv_transfer.kv_lookup_buffer.mooncake_store import ( # noqa: E501
|
||||
MooncakeStore)
|
||||
logger.info(
|
||||
"Initializing KVStoreConnector under kv_transfer_config %s",
|
||||
self.config)
|
||||
self.kv_store = MooncakeStore(config)
|
||||
else:
|
||||
logger.error("Can not find %s", self.config.kv_connector)
|
||||
|
||||
assert self.kv_store is not None
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the buffer and release resources.
|
||||
This method is responsible for cleaning up resources related to the
|
||||
connector when it is no longer needed.
|
||||
Raises:
|
||||
NotImplementedError: This method must be implemented in subclasses.
|
||||
"""
|
||||
self.kv_store.close()
|
||||
|
||||
def send_kv_caches_and_hidden_states(
|
||||
self,
|
||||
model_executable: torch.nn.Module,
|
||||
model_input: "ModelInputForGPUWithSamplingMetadata",
|
||||
kv_caches: List[torch.Tensor],
|
||||
hidden_or_intermediate_states: Union[torch.Tensor,
|
||||
IntermediateTensors],
|
||||
) -> None:
|
||||
input_tokens_tensor = model_input.input_tokens
|
||||
seq_lens = model_input.attn_metadata.seq_lens
|
||||
slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten()
|
||||
start_layer = model_executable.model.start_layer
|
||||
end_layer = model_executable.model.end_layer
|
||||
|
||||
model_config = model_executable.model.config
|
||||
num_heads = int(model_config.num_key_value_heads / self.tp_size)
|
||||
hidden_size = model_config.hidden_size
|
||||
num_attention_heads = model_config.num_attention_heads
|
||||
head_size = int(hidden_size / num_attention_heads)
|
||||
|
||||
for idx, slen in enumerate(seq_lens):
|
||||
start_pos = sum(seq_lens[:idx])
|
||||
end_pos = start_pos + slen
|
||||
|
||||
current_tokens = input_tokens_tensor[start_pos:end_pos]
|
||||
store_key_prefix = self.tensor_hash(current_tokens)
|
||||
keys, values = [], []
|
||||
|
||||
for layer_id in range(start_layer, end_layer):
|
||||
kv_cache = kv_caches[layer_id - start_layer]
|
||||
|
||||
key_cache = kv_cache[0].reshape(-1, num_heads, head_size)
|
||||
value_cache = kv_cache[1].reshape(-1, num_heads, head_size)
|
||||
|
||||
current_slot_mapping = slot_mapping_flat[start_pos:end_pos]
|
||||
|
||||
keys.append(key_cache[current_slot_mapping].unsqueeze(0))
|
||||
values.append(value_cache[current_slot_mapping].unsqueeze(0))
|
||||
|
||||
keys = torch.cat(keys, dim=0)
|
||||
values = torch.cat(values, dim=0)
|
||||
kvcache_to_sent = torch.stack((keys, values), dim=0)
|
||||
store_kvcache_key = f"{store_key_prefix}_{self.local_tp_rank}"
|
||||
self.kv_store.put(store_kvcache_key, kvcache_to_sent)
|
||||
|
||||
hidden_key = f"{store_key_prefix}_hidden_{self.local_tp_rank}"
|
||||
self.kv_store.put(hidden_key,
|
||||
hidden_or_intermediate_states[start_pos:end_pos])
|
||||
|
||||
logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank())
|
||||
|
||||
def recv_kv_caches_and_hidden_states(
|
||||
self, model_executable: torch.nn.Module,
|
||||
model_input: "ModelInputForGPUWithSamplingMetadata",
|
||||
kv_caches: List[torch.Tensor]
|
||||
) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool,
|
||||
"ModelInputForGPUWithSamplingMetadata"]:
|
||||
bypass_model_exec = True
|
||||
input_tokens_tensor = model_input.input_tokens
|
||||
seq_lens = model_input.attn_metadata.seq_lens
|
||||
num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens
|
||||
slot_mapping = model_input.attn_metadata.slot_mapping.flatten()
|
||||
start_layer = model_executable.model.start_layer
|
||||
end_layer = model_executable.model.end_layer
|
||||
hidden_or_intermediate_states_for_one_req = []
|
||||
|
||||
for idx, slen in enumerate(seq_lens):
|
||||
start_pos = sum(seq_lens[:idx])
|
||||
end_pos = start_pos + slen
|
||||
|
||||
if start_pos >= num_prefill_tokens:
|
||||
# This can happen during inflight batching. See:
|
||||
# vllm/worker/model_runner.py::_prepare_model_input_tensors:
|
||||
# - input_tokens[:num_prefill_tokens] contains prefill tokens.
|
||||
# - input_tokens[num_prefill_tokens:] contains decode tokens.
|
||||
logger.warning("You should set --enable_chunked_prefill=False "
|
||||
"and --max_num_batched_tokens "
|
||||
"should be equal to max_seq_len_to_capture")
|
||||
bypass_model_exec = False
|
||||
assert start_pos == num_prefill_tokens
|
||||
break
|
||||
|
||||
current_tokens = input_tokens_tensor[start_pos:end_pos]
|
||||
|
||||
# get roi for current seq
|
||||
load_key_prefix = self.tensor_hash(current_tokens)
|
||||
load_kvcache_key = f"{load_key_prefix}_{self.local_tp_rank}"
|
||||
remote_kv = self.kv_store.get(load_kvcache_key)
|
||||
hidden_key = f"{load_key_prefix}_hidden_{self.local_tp_rank}"
|
||||
hidden = self.kv_store.get(hidden_key)
|
||||
|
||||
if remote_kv is None or hidden is None:
|
||||
# didn't find any match.
|
||||
bypass_model_exec = False
|
||||
continue
|
||||
|
||||
num_computed_tokens = current_tokens.shape[0]
|
||||
|
||||
# update the end position based on how many tokens are cached.
|
||||
end_pos = start_pos + num_computed_tokens
|
||||
|
||||
# call self.kv_store to get kv layer by layer
|
||||
for layer_id in range(start_layer, end_layer):
|
||||
layer = model_executable.model.layers[layer_id]
|
||||
# get kvcache object
|
||||
kv_cache = kv_caches[layer_id - start_layer]
|
||||
key_cache, value_cache = kv_cache[0], kv_cache[1]
|
||||
# get remote kvcache
|
||||
|
||||
remote_k, remote_v = remote_kv[0][layer_id], remote_kv[1][
|
||||
layer_id]
|
||||
# use ops.reshape_and_cache_flash to put kv into kvcache
|
||||
ops.reshape_and_cache_flash(
|
||||
remote_k.to(key_cache.device),
|
||||
remote_v.to(value_cache.device),
|
||||
key_cache,
|
||||
value_cache,
|
||||
slot_mapping[start_pos:end_pos],
|
||||
layer.self_attn.attn.kv_cache_dtype,
|
||||
layer.self_attn.attn._k_scale,
|
||||
layer.self_attn.attn._v_scale,
|
||||
)
|
||||
|
||||
hidden_or_intermediate_states_for_one_req.append(hidden)
|
||||
|
||||
if not bypass_model_exec:
|
||||
logger.warning(
|
||||
"[rank%d]: Failed to receive all KVs and hidden "
|
||||
"states, redo model forwarding.", torch.distributed.get_rank())
|
||||
hidden_or_intermediate_states = None
|
||||
|
||||
else:
|
||||
logger.debug(
|
||||
"[rank%d]: Successfully received all KVs and hidden "
|
||||
"states, skip model forwarding.", torch.distributed.get_rank())
|
||||
hidden_or_intermediate_states = torch.cat(
|
||||
hidden_or_intermediate_states_for_one_req, dim=0)
|
||||
|
||||
return hidden_or_intermediate_states, bypass_model_exec, model_input
|
||||
|
||||
@staticmethod
|
||||
def tensor_hash(tensor: torch.Tensor) -> int:
|
||||
"""Calculate the hash value of the tensor."""
|
||||
tensor_bytes = tensor.clone().detach().cpu().numpy().tobytes()
|
||||
hash_object = hashlib.blake2b(tensor_bytes)
|
||||
hash_hex = hash_object.hexdigest()
|
||||
return int(hash_hex[:16], 16)
|
||||
306
vllm/distributed/kv_transfer/kv_connector/p2p_connector.py
Normal file
306
vllm/distributed/kv_transfer/kv_connector/p2p_connector.py
Normal file
@@ -0,0 +1,306 @@
|
||||
# Mainly adopted from https://github.com/FlagOpen/FlagScale/blob/44ceca57dd6f86b10163968e617497c613e47d6e/flagscale/backends/vllm/vllm/distributed/kv_transfer/kv_connector/p2p_connector.py.
|
||||
# Below is the original copyright:
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import os
|
||||
import re
|
||||
from typing import TYPE_CHECKING, List, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
|
||||
if os.getenv("USE_FLAGCX", "false").lower() in ("1", "true"):
|
||||
from vllm.distributed.kv_transfer.kv_pipe.flagcx_p2p_nccl_pipe import P2pNcclPipe
|
||||
else:
|
||||
from vllm.distributed.kv_transfer.kv_pipe.p2p_nccl_pipe import P2pNcclPipe
|
||||
from vllm.logger import init_logger
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class P2pConnector(KVConnectorBase):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
rank: int,
|
||||
local_rank: int,
|
||||
config: VllmConfig,
|
||||
):
|
||||
self.rank = rank
|
||||
self.config = config.kv_transfer_config
|
||||
self.tp_size = config.parallel_config.tensor_parallel_size
|
||||
self.is_deepseek_mla = config.model_config.is_deepseek_mla
|
||||
self.use_mla_opt = not envs.VLLM_MLA_DISABLE
|
||||
|
||||
assert self.config.kv_connector == "P2pConnector"
|
||||
|
||||
self.lookup_buffer_size = self.config.kv_buffer_size
|
||||
|
||||
self.p2p_nccl_pipe = P2pNcclPipe(
|
||||
local_rank=local_rank,
|
||||
config=self.config,
|
||||
hostname="",
|
||||
port_offset=rank,
|
||||
)
|
||||
|
||||
def send_kv_caches_and_hidden_states(
|
||||
self,
|
||||
model_executable: torch.nn.Module,
|
||||
model_input: "ModelInputForGPUWithSamplingMetadata",
|
||||
kv_caches: List[torch.Tensor],
|
||||
hidden_or_intermediate_states: Union[torch.Tensor,
|
||||
IntermediateTensors],
|
||||
) -> None:
|
||||
|
||||
# input_tokens_tensor = model_input.input_tokens
|
||||
seq_lens = model_input.attn_metadata.seq_lens
|
||||
slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten()
|
||||
num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens
|
||||
request_ids = list(model_input.request_ids_to_seq_ids.keys())
|
||||
start_layer = model_executable.model.start_layer
|
||||
end_layer = model_executable.model.end_layer
|
||||
|
||||
model_config = model_executable.model.config
|
||||
num_heads = int(model_config.num_key_value_heads / self.tp_size)
|
||||
hidden_size = model_config.hidden_size
|
||||
num_attention_heads = model_config.num_attention_heads
|
||||
|
||||
# Deepseek's MLA (Multi-head Latent Attention) uses two different
|
||||
# kv_cache shapes based on whether VLLM_MLA_DISABLE is set to 0.
|
||||
# When VLLM_MLA_DISABLE=0 (default), forward absorb is applied,
|
||||
# resulting in a kv_cache shape of [num_blks, blk_size, 1,
|
||||
# kv_lora_rank + qk_rope_head_dim].
|
||||
# When VLLM_MLA_DISABLE=1, standard FA is used instead, leading
|
||||
# to a kv_cache shape of [2, num_blks, blk_size,
|
||||
# num_key_value_heads / tp, qk_nope_head_dim + qk_rope_head_dim].
|
||||
# For more details, see vllm/attention/backends/mla/common.py.
|
||||
if self.is_deepseek_mla and self.use_mla_opt:
|
||||
head_size = model_config.kv_lora_rank + \
|
||||
model_config.qk_rope_head_dim
|
||||
num_heads = 1
|
||||
elif self.is_deepseek_mla and not self.use_mla_opt:
|
||||
head_size = model_config.qk_nope_head_dim + \
|
||||
model_config.qk_rope_head_dim
|
||||
else:
|
||||
head_size = getattr(model_config, "head_dim",
|
||||
int(hidden_size // num_attention_heads))
|
||||
|
||||
# query_lens contains new KV caches that are added to vLLM.
|
||||
# so we will send them to decode instance
|
||||
# FIXME(Kuntai): This assume that all requests are prefill.
|
||||
for idx, slen in enumerate(seq_lens):
|
||||
start_pos = sum(seq_lens[:idx])
|
||||
end_pos = start_pos + slen
|
||||
|
||||
if start_pos >= num_prefill_tokens:
|
||||
# vllm/worker/model_runner.py::_prepare_model_input_tensors:
|
||||
# - input_tokens[:num_prefill_tokens] contains prefill tokens.
|
||||
# - input_tokens[num_prefill_tokens:] contains decode tokens.
|
||||
logger.warning("You have some decode requests while using "
|
||||
"SimpleConnector. Their KVCache won't be sent.")
|
||||
break
|
||||
|
||||
# current_tokens = input_tokens_tensor[start_pos:end_pos]
|
||||
|
||||
keys, values = [], []
|
||||
|
||||
for layer_id in range(start_layer, end_layer):
|
||||
kv_cache = kv_caches[layer_id - start_layer]
|
||||
## 更改kv_cache 排布
|
||||
kv_cache = kv_cache.permute(0, 1, 3, 2, 4)
|
||||
|
||||
if self.is_deepseek_mla and self.use_mla_opt:
|
||||
key_cache = kv_cache.reshape(-1, num_heads, head_size)
|
||||
value_cache = kv_cache.reshape(-1, num_heads, head_size)
|
||||
else:
|
||||
key_cache = kv_cache[0].reshape(-1, num_heads, head_size)
|
||||
value_cache = kv_cache[1].reshape(-1, num_heads, head_size)
|
||||
|
||||
current_slot_mapping = slot_mapping_flat[start_pos:end_pos]
|
||||
|
||||
keys.append(key_cache[current_slot_mapping].unsqueeze(0))
|
||||
values.append(value_cache[current_slot_mapping].unsqueeze(0))
|
||||
|
||||
keys = torch.cat(keys, dim=0)
|
||||
values = torch.cat(values, dim=0)
|
||||
|
||||
request_id = request_ids[idx]
|
||||
ip, port = self.parse_request_id(request_id, True)
|
||||
remote_address = ip + ":" + str(port + self.rank)
|
||||
|
||||
self.p2p_nccl_pipe.send_tensor(request_id + "keys", keys,
|
||||
remote_address)
|
||||
self.p2p_nccl_pipe.send_tensor(request_id + "values", values,
|
||||
remote_address)
|
||||
self.p2p_nccl_pipe.send_tensor(
|
||||
request_id + "hidden",
|
||||
hidden_or_intermediate_states[start_pos:end_pos],
|
||||
remote_address)
|
||||
|
||||
logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank())
|
||||
|
||||
def recv_kv_caches_and_hidden_states(
|
||||
self, model_executable: torch.nn.Module,
|
||||
model_input: "ModelInputForGPUWithSamplingMetadata",
|
||||
kv_caches: List[torch.Tensor]
|
||||
) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool,
|
||||
"ModelInputForGPUWithSamplingMetadata"]:
|
||||
|
||||
# When bypass_model_exec is set to False, it means that at least for one
|
||||
# request its corresponding KV cache or hidden state is missing.
|
||||
# In this case we need to do prefilling to recompute missing KV cache
|
||||
# and hidden states.
|
||||
bypass_model_exec = True
|
||||
|
||||
model_config = model_executable.model.config
|
||||
|
||||
input_tokens_tensor = model_input.input_tokens
|
||||
seq_lens = model_input.attn_metadata.seq_lens
|
||||
num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens
|
||||
slot_mapping = model_input.attn_metadata.slot_mapping.flatten()
|
||||
request_ids = list(model_input.request_ids_to_seq_ids.keys())
|
||||
|
||||
hidden_or_intermediate_states_for_one_req = []
|
||||
|
||||
input_tokens_list = []
|
||||
num_computed_tokens_list = []
|
||||
start_pos_list = []
|
||||
|
||||
# enumerate different requests
|
||||
# FIXME(Kuntai): This impl assumes that all requests are prefill.
|
||||
for idx, slen in enumerate(seq_lens):
|
||||
start_pos = sum(seq_lens[:idx])
|
||||
end_pos = start_pos + slen
|
||||
|
||||
if start_pos >= num_prefill_tokens:
|
||||
# This can happen during inflight batching. See:
|
||||
# vllm/worker/model_runner.py::_prepare_model_input_tensors:
|
||||
# - input_tokens[:num_prefill_tokens] contains prefill tokens.
|
||||
# - input_tokens[num_prefill_tokens:] contains decode tokens.
|
||||
logger.warning("You should set --enable_chunked_prefill=False "
|
||||
"and --max_num_batched_tokens "
|
||||
"should be equal to --max_seq_len_to_capture")
|
||||
bypass_model_exec = False
|
||||
assert start_pos == num_prefill_tokens
|
||||
break
|
||||
|
||||
current_tokens = input_tokens_tensor[start_pos:end_pos]
|
||||
num_tokens = slen
|
||||
|
||||
# collecting data for rebuilding the input
|
||||
input_tokens_list.append(current_tokens)
|
||||
start_pos_list.append(start_pos)
|
||||
|
||||
request_id = request_ids[idx]
|
||||
ip, port = self.parse_request_id(request_id, False)
|
||||
remote_address = ip + ":" + str(port + self.rank)
|
||||
|
||||
keys = self.p2p_nccl_pipe.recv_tensor(request_id + "keys",
|
||||
remote_address)
|
||||
values = self.p2p_nccl_pipe.recv_tensor(request_id + "values",
|
||||
remote_address)
|
||||
hidden = self.p2p_nccl_pipe.recv_tensor(request_id + "hidden",
|
||||
remote_address)
|
||||
|
||||
num_computed_tokens = current_tokens.shape[0]
|
||||
num_computed_tokens_list.append(num_computed_tokens)
|
||||
|
||||
# check if both KV cache and the hidden states are received
|
||||
# If not, need to redo the forwarding to compute missing states
|
||||
if not all([(num_computed_tokens == num_tokens), keys is not None,
|
||||
values is not None, hidden is not None]):
|
||||
bypass_model_exec = False
|
||||
break
|
||||
|
||||
# update the end position based on how many tokens are cached.
|
||||
end_pos = start_pos + num_computed_tokens
|
||||
|
||||
# put received KV caches into paged memory
|
||||
for i in range(model_executable.model.start_layer,
|
||||
model_executable.model.end_layer):
|
||||
|
||||
kv_cache = kv_caches[i - model_executable.model.start_layer]
|
||||
layer = model_executable.model.layers[i]
|
||||
|
||||
if self.is_deepseek_mla and self.use_mla_opt:
|
||||
layer.self_attn.attn = layer.self_attn.mla_attn
|
||||
k_c_normed_k_pe = keys[
|
||||
i - model_executable.model.start_layer].to(
|
||||
kv_cache.device).squeeze(1)
|
||||
k_c_normed = k_c_normed_k_pe[:, :model_config.kv_lora_rank]
|
||||
k_pe = k_c_normed_k_pe[:, model_config.kv_lora_rank:]
|
||||
ops.concat_and_cache_mla(
|
||||
k_c_normed,
|
||||
k_pe,
|
||||
kv_cache,
|
||||
slot_mapping[start_pos:end_pos],
|
||||
layer.self_attn.attn.kv_cache_dtype,
|
||||
layer.self_attn.attn._k_scale,
|
||||
)
|
||||
else:
|
||||
key_cache, value_cache = kv_cache[0], kv_cache[1]
|
||||
ops.reshape_and_cache_flash(
|
||||
keys[i - model_executable.model.start_layer].to(
|
||||
key_cache.device),
|
||||
values[i - model_executable.model.start_layer].to(
|
||||
value_cache.device),
|
||||
key_cache,
|
||||
value_cache,
|
||||
slot_mapping[start_pos:end_pos],
|
||||
layer.self_attn.attn.kv_cache_dtype,
|
||||
layer.self_attn.attn._k_scale,
|
||||
layer.self_attn.attn._v_scale,
|
||||
)
|
||||
|
||||
hidden_or_intermediate_states_for_one_req.append(hidden)
|
||||
|
||||
if not bypass_model_exec:
|
||||
# Some of the KV cache is not retrieved
|
||||
# Here we will fall back to normal model forwarding
|
||||
# But optionally you can adjust model_input so that you only do
|
||||
# prefilling on those tokens that are missing KV caches.
|
||||
logger.warning(
|
||||
"[rank%d]: Failed to receive all KVs and hidden "
|
||||
"states, redo model forwarding.", torch.distributed.get_rank())
|
||||
hidden_or_intermediate_states = None
|
||||
|
||||
else:
|
||||
logger.debug(
|
||||
"[rank%d]: Successfully received all KVs and hidden "
|
||||
"states, skip model forwarding.", torch.distributed.get_rank())
|
||||
hidden_or_intermediate_states = torch.cat(
|
||||
hidden_or_intermediate_states_for_one_req, dim=0)
|
||||
|
||||
return hidden_or_intermediate_states, bypass_model_exec, model_input
|
||||
|
||||
@staticmethod
|
||||
def parse_request_id(request_id: str, is_prefill=True) -> Tuple[str, int]:
|
||||
logger.debug("parse_request_id, request_id: %s, is_prefill: %s",
|
||||
request_id, is_prefill)
|
||||
# Regular expression to match the string hostname and integer port
|
||||
if is_prefill:
|
||||
pattern = r"___decode_addr_(.*):(\d+)"
|
||||
else:
|
||||
pattern = r"___prefill_addr_(.*):(\d+)___"
|
||||
|
||||
# Use re.search to find the pattern in the request_id
|
||||
match = re.search(pattern, request_id)
|
||||
if match:
|
||||
# Extract the ranks
|
||||
ip = match.group(1)
|
||||
port = int(match.group(2))
|
||||
|
||||
logger.debug("parse_request_id, request_id: %s, ip: %s, port: %s",
|
||||
request_id, ip, str(port))
|
||||
return ip, port
|
||||
raise ValueError(
|
||||
f"Request id {request_id} does not contain hostname and port")
|
||||
|
||||
def close(self):
|
||||
self.p2p_nccl_pipe.close()
|
||||
382
vllm/distributed/kv_transfer/kv_connector/simple_connector.py
Normal file
382
vllm/distributed/kv_transfer/kv_connector/simple_connector.py
Normal file
@@ -0,0 +1,382 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""
|
||||
Simple KV Cache Connector for Distributed Machine Learning Inference
|
||||
|
||||
The SimpleConnector transfers KV caches between prefill vLLM worker (KV cache
|
||||
producer) and decode vLLM worker (KV cache consumer) using PyNcclPipe or
|
||||
MooncakePipe.
|
||||
|
||||
But the logic can be extended to support other pipe and lookup buffer.
|
||||
"""
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
|
||||
from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer import (
|
||||
SimpleBuffer)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class SimpleConnector(KVConnectorBase):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
rank: int,
|
||||
local_rank: int,
|
||||
config: VllmConfig,
|
||||
):
|
||||
|
||||
self.config = config.kv_transfer_config
|
||||
self.tp_size = config.parallel_config.tensor_parallel_size
|
||||
self.is_deepseek_mla = config.model_config.is_deepseek_mla
|
||||
self.use_mla_opt = not envs.VLLM_MLA_DISABLE
|
||||
|
||||
if self.config.kv_connector == "PyNcclConnector":
|
||||
from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import (
|
||||
PyNcclPipe)
|
||||
logger.info(
|
||||
"Initializing PyNcclConfig under kv_transfer_config %s",
|
||||
self.config)
|
||||
elif self.config.kv_connector == "MooncakeConnector":
|
||||
# Check if MOONCAKE_CONFIG_PATH is set
|
||||
import os
|
||||
use_mooncake_distributed_pipe = os.getenv(
|
||||
'MOONCAKE_CONFIG_PATH') is not None
|
||||
|
||||
if not use_mooncake_distributed_pipe:
|
||||
raise ValueError(
|
||||
"To use MooncakeConnector, you need to pass the ENV: "
|
||||
"'MOONCAKE_CONFIG_PATH=/path/to/mooncake_config.json'.")
|
||||
else:
|
||||
from vllm.distributed.kv_transfer.kv_pipe.mooncake_pipe import ( # noqa: E501
|
||||
MooncakePipe)
|
||||
logger.info(
|
||||
"Initializing MooncakeConfig under kv_transfer_config %s",
|
||||
self.config)
|
||||
|
||||
self.lookup_buffer_size = self.config.kv_buffer_size
|
||||
|
||||
self.producer_buffer: Optional[SimpleBuffer] = None
|
||||
self.consumer_buffer: Optional[SimpleBuffer] = None
|
||||
|
||||
self.producer_data_pipe: Union[PyNcclPipe, MooncakePipe]
|
||||
self.consumer_data_pipe: Union[PyNcclPipe, MooncakePipe]
|
||||
self.producer_signal_pipe: Union[PyNcclPipe, MooncakePipe]
|
||||
self.consumer_signal_pipe: Union[PyNcclPipe, MooncakePipe]
|
||||
|
||||
# 2 pipes for every rank in the world
|
||||
port_offset_base = 2 * rank
|
||||
|
||||
# In disaggregated prefill, the prefill vLLM only uses send pipe
|
||||
# and the decode vLLM only uses recv pipe
|
||||
if self.config.is_kv_producer:
|
||||
|
||||
if self.config.kv_connector == "PyNcclConnector":
|
||||
self.producer_data_pipe = PyNcclPipe(
|
||||
local_rank=local_rank,
|
||||
config=self.config,
|
||||
port_offset=port_offset_base,
|
||||
)
|
||||
self.producer_signal_pipe = PyNcclPipe(
|
||||
local_rank=local_rank,
|
||||
config=self.config,
|
||||
port_offset=port_offset_base + 1,
|
||||
device="cpu",
|
||||
)
|
||||
elif self.config.kv_connector == "MooncakeConnector":
|
||||
self.producer_data_pipe = MooncakePipe(
|
||||
local_rank=local_rank,
|
||||
config=self.config,
|
||||
)
|
||||
# We only need to initialize MooncakePipe once
|
||||
self.producer_signal_pipe = self.producer_data_pipe
|
||||
|
||||
self.producer_buffer = SimpleBuffer(self.producer_signal_pipe,
|
||||
self.producer_data_pipe,
|
||||
self.config.kv_buffer_size)
|
||||
|
||||
else:
|
||||
|
||||
# the current vLLM instance is KV consumer, so it needs to connect
|
||||
# its recv pipe to the send pipe of KV producder
|
||||
if self.config.kv_connector == "PyNcclConnector":
|
||||
self.consumer_data_pipe = PyNcclPipe(
|
||||
local_rank=local_rank,
|
||||
config=self.config,
|
||||
port_offset=port_offset_base,
|
||||
)
|
||||
self.consumer_signal_pipe = PyNcclPipe(
|
||||
local_rank=local_rank,
|
||||
config=self.config,
|
||||
port_offset=port_offset_base + 1,
|
||||
device="cpu",
|
||||
)
|
||||
elif self.config.kv_connector == "MooncakeConnector":
|
||||
self.consumer_data_pipe = MooncakePipe(
|
||||
local_rank=local_rank,
|
||||
config=self.config,
|
||||
)
|
||||
self.consumer_signal_pipe = self.consumer_data_pipe
|
||||
|
||||
self.consumer_buffer = SimpleBuffer(
|
||||
self.consumer_signal_pipe,
|
||||
self.consumer_data_pipe,
|
||||
self.config.kv_buffer_size,
|
||||
)
|
||||
|
||||
def select(self, input_tokens: Optional[torch.Tensor],
|
||||
roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]:
|
||||
|
||||
assert self.consumer_buffer is not None, "Please initialize the "\
|
||||
"consumer buffer before calling select."
|
||||
return self.consumer_buffer.drop_select(input_tokens, roi)
|
||||
|
||||
def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor,
|
||||
key: torch.Tensor, value: torch.Tensor,
|
||||
hidden: torch.Tensor) -> None:
|
||||
|
||||
assert self.producer_buffer is not None, "Please initialize the "\
|
||||
"producer buffer before calling insert."
|
||||
|
||||
self.producer_buffer.insert(input_tokens, roi, key, value, hidden)
|
||||
|
||||
def send_kv_caches_and_hidden_states(
|
||||
self,
|
||||
model_executable: torch.nn.Module,
|
||||
model_input: "ModelInputForGPUWithSamplingMetadata",
|
||||
kv_caches: List[torch.Tensor],
|
||||
hidden_or_intermediate_states: Union[torch.Tensor,
|
||||
IntermediateTensors],
|
||||
) -> None:
|
||||
|
||||
input_tokens_tensor = model_input.input_tokens
|
||||
seq_lens = model_input.attn_metadata.seq_lens
|
||||
slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten()
|
||||
num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens
|
||||
start_layer = model_executable.model.start_layer
|
||||
end_layer = model_executable.model.end_layer
|
||||
|
||||
model_config = model_executable.model.config
|
||||
num_heads = int(model_config.num_key_value_heads / self.tp_size)
|
||||
hidden_size = model_config.hidden_size
|
||||
num_attention_heads = model_config.num_attention_heads
|
||||
|
||||
# Deepseek's MLA (Multi-head Latent Attention) uses two different
|
||||
# kv_cache shapes based on whether VLLM_MLA_DISABLE is set to 0.
|
||||
# When VLLM_MLA_DISABLE=0 (default), forward absorb is applied,
|
||||
# resulting in a kv_cache shape of [num_blks, blk_size, 1,
|
||||
# kv_lora_rank + qk_rope_head_dim].
|
||||
# When VLLM_MLA_DISABLE=1, standard FA is used instead, leading
|
||||
# to a kv_cache shape of [2, num_blks, blk_size,
|
||||
# num_key_value_heads / tp, qk_nope_head_dim + qk_rope_head_dim].
|
||||
# For more details, see vllm/attention/backends/mla/common.py.
|
||||
if self.is_deepseek_mla and self.use_mla_opt:
|
||||
head_size = model_config.kv_lora_rank + \
|
||||
model_config.qk_rope_head_dim
|
||||
num_heads = 1
|
||||
elif self.is_deepseek_mla and not self.use_mla_opt:
|
||||
head_size = model_config.qk_nope_head_dim + \
|
||||
model_config.qk_rope_head_dim
|
||||
else:
|
||||
head_size = getattr(model_config, "head_dim",
|
||||
int(hidden_size // num_attention_heads))
|
||||
|
||||
# query_lens contains new KV caches that are added to vLLM.
|
||||
# so we will send them to decode instance
|
||||
# FIXME(Kuntai): This assume that all requests are prefill.
|
||||
for idx, slen in enumerate(seq_lens):
|
||||
start_pos = sum(seq_lens[:idx])
|
||||
end_pos = start_pos + slen
|
||||
|
||||
if start_pos >= num_prefill_tokens:
|
||||
# vllm/worker/model_runner.py::_prepare_model_input_tensors:
|
||||
# - input_tokens[:num_prefill_tokens] contains prefill tokens.
|
||||
# - input_tokens[num_prefill_tokens:] contains decode tokens.
|
||||
logger.warning("You have some decode requests while using "
|
||||
"SimpleConnector. Their KVCache won't be sent.")
|
||||
break
|
||||
|
||||
current_tokens = input_tokens_tensor[start_pos:end_pos]
|
||||
|
||||
keys, values = [], []
|
||||
|
||||
for layer_id in range(start_layer, end_layer):
|
||||
kv_cache = kv_caches[layer_id - start_layer]
|
||||
# TODO, fix this
|
||||
kv_cache = kv_cache.permute(0, 1, 3, 2, 4)
|
||||
if self.is_deepseek_mla and self.use_mla_opt:
|
||||
key_cache = kv_cache.reshape(-1, num_heads, head_size)
|
||||
value_cache = kv_cache.reshape(-1, num_heads, head_size)
|
||||
else:
|
||||
key_cache = kv_cache[0].reshape(-1, num_heads, head_size)
|
||||
value_cache = kv_cache[1].reshape(-1, num_heads, head_size)
|
||||
|
||||
current_slot_mapping = slot_mapping_flat[start_pos:end_pos]
|
||||
|
||||
keys.append(key_cache[current_slot_mapping].unsqueeze(0))
|
||||
values.append(value_cache[current_slot_mapping].unsqueeze(0))
|
||||
|
||||
keys = torch.cat(keys, dim=0)
|
||||
values = torch.cat(values, dim=0)
|
||||
|
||||
self.insert(current_tokens,
|
||||
torch.ones_like(current_tokens,
|
||||
dtype=bool), keys, values,
|
||||
hidden_or_intermediate_states[start_pos:end_pos])
|
||||
|
||||
logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank())
|
||||
|
||||
def recv_kv_caches_and_hidden_states(
|
||||
self, model_executable: torch.nn.Module,
|
||||
model_input: "ModelInputForGPUWithSamplingMetadata",
|
||||
kv_caches: List[torch.Tensor]
|
||||
) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool,
|
||||
"ModelInputForGPUWithSamplingMetadata"]:
|
||||
|
||||
# When bypass_model_exec is set to False, it means that at least for one
|
||||
# request its corresponding KV cache or hidden state is missing.
|
||||
# In this case we need to do prefilling to recompute missing KV cache
|
||||
# and hidden states.
|
||||
bypass_model_exec = True
|
||||
|
||||
model_config = model_executable.model.config
|
||||
|
||||
input_tokens_tensor = model_input.input_tokens
|
||||
seq_lens = model_input.attn_metadata.seq_lens
|
||||
num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens
|
||||
slot_mapping = model_input.attn_metadata.slot_mapping.flatten()
|
||||
|
||||
hidden_or_intermediate_states_for_one_req = []
|
||||
|
||||
input_tokens_list = []
|
||||
num_computed_tokens_list = []
|
||||
start_pos_list = []
|
||||
|
||||
# enumerate different requests
|
||||
# FIXME(Kuntai): This impl assumes that all requests are prefill.
|
||||
for idx, slen in enumerate(seq_lens):
|
||||
start_pos = sum(seq_lens[:idx])
|
||||
end_pos = start_pos + slen
|
||||
|
||||
if start_pos >= num_prefill_tokens:
|
||||
# This can happen during inflight batching. See:
|
||||
# vllm/worker/model_runner.py::_prepare_model_input_tensors:
|
||||
# - input_tokens[:num_prefill_tokens] contains prefill tokens.
|
||||
# - input_tokens[num_prefill_tokens:] contains decode tokens.
|
||||
logger.warning("You should set --enable_chunked_prefill=False "
|
||||
"and --max_num_batched_tokens "
|
||||
"should be equal to --max_seq_len_to_capture")
|
||||
bypass_model_exec = False
|
||||
assert start_pos == num_prefill_tokens
|
||||
break
|
||||
|
||||
current_tokens = input_tokens_tensor[start_pos:end_pos]
|
||||
num_tokens = slen
|
||||
|
||||
# collecting data for rebuilding the input
|
||||
input_tokens_list.append(current_tokens)
|
||||
start_pos_list.append(start_pos)
|
||||
|
||||
ret = self.select(current_tokens,
|
||||
torch.ones_like(current_tokens, dtype=bool))
|
||||
if ret[0] is None:
|
||||
# didn't find any match.
|
||||
bypass_model_exec = False
|
||||
num_computed_tokens_list.append(0)
|
||||
continue
|
||||
|
||||
roi: torch.Tensor = ret[1]
|
||||
keys: torch.Tensor = ret[2]
|
||||
values: torch.Tensor = ret[3]
|
||||
hidden: torch.Tensor = ret[4]
|
||||
|
||||
num_computed_tokens = roi.shape[0]
|
||||
num_computed_tokens_list.append(num_computed_tokens)
|
||||
|
||||
# check if both KV cache and the hidden states are received
|
||||
# If not, need to redo the forwarding to compute missing states
|
||||
if not all([(num_computed_tokens == num_tokens), hidden is not None
|
||||
]):
|
||||
bypass_model_exec = False
|
||||
|
||||
# update the end position based on how many tokens are cached.
|
||||
end_pos = start_pos + num_computed_tokens
|
||||
|
||||
# put received KV caches into paged memory
|
||||
for i in range(model_executable.model.start_layer,
|
||||
model_executable.model.end_layer):
|
||||
|
||||
kv_cache = kv_caches[i - model_executable.model.start_layer]
|
||||
layer = model_executable.model.layers[i]
|
||||
|
||||
if self.is_deepseek_mla and self.use_mla_opt:
|
||||
layer.self_attn.attn = layer.self_attn.mla_attn
|
||||
k_c_normed_k_pe = keys[
|
||||
i - model_executable.model.start_layer].to(
|
||||
kv_cache.device).squeeze(1)
|
||||
k_c_normed = k_c_normed_k_pe[:, :model_config.kv_lora_rank]
|
||||
k_pe = k_c_normed_k_pe[:, model_config.kv_lora_rank:]
|
||||
ops.concat_and_cache_mla(
|
||||
k_c_normed,
|
||||
k_pe,
|
||||
kv_cache,
|
||||
slot_mapping[start_pos:end_pos],
|
||||
layer.self_attn.attn.kv_cache_dtype,
|
||||
layer.self_attn.attn._k_scale,
|
||||
)
|
||||
else:
|
||||
key_cache, value_cache = kv_cache[0], kv_cache[1]
|
||||
ops.reshape_and_cache_flash(
|
||||
keys[i - model_executable.model.start_layer].to(
|
||||
key_cache.device),
|
||||
values[i - model_executable.model.start_layer].to(
|
||||
value_cache.device),
|
||||
key_cache,
|
||||
value_cache,
|
||||
slot_mapping[start_pos:end_pos],
|
||||
layer.self_attn.attn.kv_cache_dtype,
|
||||
layer.self_attn.attn._k_scale,
|
||||
layer.self_attn.attn._v_scale,
|
||||
)
|
||||
|
||||
hidden_or_intermediate_states_for_one_req.append(hidden)
|
||||
|
||||
if not bypass_model_exec:
|
||||
# Some of the KV cache is not retrieved
|
||||
# Here we will fall back to normal model forwarding
|
||||
# But optionally you can adjust model_input so that you only do
|
||||
# prefilling on those tokens that are missing KV caches.
|
||||
logger.warning(
|
||||
"[rank%d]: Failed to receive all KVs and hidden "
|
||||
"states, redo model forwarding.", torch.distributed.get_rank())
|
||||
hidden_or_intermediate_states = None
|
||||
|
||||
else:
|
||||
logger.debug(
|
||||
"[rank%d]: Successfully received all KVs and hidden "
|
||||
"states, skip model forwarding.", torch.distributed.get_rank())
|
||||
hidden_or_intermediate_states = torch.cat(
|
||||
hidden_or_intermediate_states_for_one_req, dim=0)
|
||||
|
||||
return hidden_or_intermediate_states, bypass_model_exec, model_input
|
||||
|
||||
def close(self):
|
||||
self.producer_data_pipe.close()
|
||||
self.consumer_data_pipe.close()
|
||||
if self.config.kv_connector == "PyNcclConnector":
|
||||
self.producer_signal_pipe.close()
|
||||
self.consumer_signal_pipe.close()
|
||||
elif self.config.kv_connector == "MooncakeConnector":
|
||||
# MooncakePipe reuses data_pipe for signal_pipe, so we only have to
|
||||
# close the data_pipe.
|
||||
pass
|
||||
Reference in New Issue
Block a user