72 lines
2.3 KiB
Python
72 lines
2.3 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
from typing import TYPE_CHECKING, Optional
|
|
|
|
from vllm import envs
|
|
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType
|
|
from vllm.distributed.kv_transfer.kv_connector.factory import (
|
|
KVConnectorFactory)
|
|
from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1,
|
|
KVConnectorRole)
|
|
from vllm.distributed.parallel_state import get_world_group
|
|
|
|
if TYPE_CHECKING:
|
|
from vllm.config import VllmConfig
|
|
|
|
# Module-level KV connector instance. It stays None until
# ensure_kv_transfer_initialized() assigns it, and is read through
# get_kv_transfer_group(), has_kv_transfer_group() and
# is_v1_kv_transfer_group().
_KV_CONNECTOR_AGENT: Optional[KVConnectorBaseType] = None
|
|
|
|
|
|
def get_kv_transfer_group() -> KVConnectorBaseType:
    """Return the module-level KV connector.

    Returns:
        The connector stored by ``ensure_kv_transfer_initialized``.

    Raises:
        AssertionError: If no connector has been initialized yet. This is
            an ``assert`` check, so it is skipped when Python runs with
            ``-O``.
    """
    agent = _KV_CONNECTOR_AGENT
    assert agent is not None, (
        "disaggregated KV cache transfer parallel group is not initialized")
    return agent
|
|
|
|
|
|
def has_kv_transfer_group() -> bool:
    """Report whether the module-level KV connector has been created.

    Returns:
        True once ``ensure_kv_transfer_initialized`` has stored a
        connector, False while it is still None.
    """
    agent = _KV_CONNECTOR_AGENT
    return agent is not None
|
|
|
|
|
|
def is_v1_kv_transfer_group(
        connector: Optional[KVConnectorBaseType] = None) -> bool:
    """Return whether a KV connector is a v1 (``KVConnectorBase_V1``) one.

    Args:
        connector: The connector to inspect. When omitted or None, the
            module-level connector created by
            ``ensure_kv_transfer_initialized`` is inspected instead.

    Returns:
        True if the selected connector exists and is an instance of
        ``KVConnectorBase_V1``. False otherwise, including when no
        connector has been initialized.

    Note:
        This helper will no longer be needed once the v1 KV connector
        becomes the default.
    """
    target = _KV_CONNECTOR_AGENT if connector is None else connector
    # An uninitialized global connector is simply reported as "not v1".
    return target is not None and isinstance(target, KVConnectorBase_V1)
|
|
|
|
|
|
def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None:
    """Create the module-level KV connector if this instance does KV transfer.

    This does nothing when ``vllm_config.kv_transfer_config`` is None, when
    that config reports it is not a KV-transfer instance, or when a
    connector has already been created. Repeated calls therefore do not
    replace an existing connector.

    Args:
        vllm_config: The engine configuration. Its ``kv_transfer_config``
            decides whether a connector is built.
    """
    global _KV_CONNECTOR_AGENT

    transfer_cfg = vllm_config.kv_transfer_config
    if transfer_cfg is None:
        return

    # Same evaluation order as "is_kv_transfer_instance and agent is None",
    # inverted into a guard clause.
    if (not transfer_cfg.is_kv_transfer_instance
            or _KV_CONNECTOR_AGENT is not None):
        return

    if envs.VLLM_USE_V1:
        _KV_CONNECTOR_AGENT = KVConnectorFactory.create_connector_v1(
            config=vllm_config, role=KVConnectorRole.WORKER)
        return

    # The v0 connector needs this process's rank within the world group.
    world = get_world_group()
    _KV_CONNECTOR_AGENT = KVConnectorFactory.create_connector_v0(
        rank=world.rank,
        local_rank=world.local_rank,
        config=vllm_config,
    )
|