[Lint] Style: Convert vllm-ascend/ to ruff format (Batch #5) (#5996)

### What this PR does / why we need it?
**Scope of Changes**:
| File Path |
| :--- |
| `.../distributed/kv_transfer/kv_pool/ascend_store/ascend_store_connector.py` |
| `vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/backend/backend.py` |
| `.../distributed/kv_transfer/kv_pool/ascend_store/backend/memcache_backend.py` |
| `.../distributed/kv_transfer/kv_pool/ascend_store/backend/mooncake_backend.py` |
| `vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/config_data.py` |
| `vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/kv_transfer.py` |
| `vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/pool_scheduler.py` |
| `vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/pool_worker.py` |
| `.../distributed/kv_transfer/kv_pool/cpu_offload/cpu_kv_cache_manager.py` |
| `.../distributed/kv_transfer/kv_pool/cpu_offload/cpu_offload_connector.py` |
| `vllm_ascend/distributed/kv_transfer/kv_pool/cpu_offload/metadata.py` |
| `vllm_ascend/distributed/kv_transfer/kv_pool/ucm_connector.py` |
| `vllm_ascend/distributed/kv_transfer/utils/mooncake_transfer_engine.py` |
| `vllm_ascend/distributed/kv_transfer/utils/utils.py` |
| `vllm_ascend/kv_offload/cpu_npu.py` |
| `vllm_ascend/kv_offload/npu.py` |
| `vllm_ascend/lora/lora_ops.py` |
| `vllm_ascend/lora/punica_npu.py` |
| `vllm_ascend/lora/utils.py` |

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.13.0
- vLLM main: 2c24bc6996

---------

Signed-off-by: MrZ20 <2609716663@qq.com>
Signed-off-by: SILONG ZENG <2609716663@qq.com>
This commit is contained in:
SILONG ZENG
2026-01-24 22:45:38 +08:00
committed by GitHub
parent 7faa6878a6
commit 6ccccad102
21 changed files with 866 additions and 1034 deletions

View File

@@ -1,6 +1,8 @@
[mypy] [mypy]
; warn_return_any = True ; warn_return_any = True
warn_unused_configs = True warn_unused_configs = True
; disable errors about unchecked annotations for now.
disable_error_code = annotation-unchecked
; Suppress all missing import errors from torch_npu for mypy. ; Suppress all missing import errors from torch_npu for mypy.
[mypy-torch_npu.*] [mypy-torch_npu.*]
@@ -31,4 +33,4 @@ ignore_missing_imports = True
ignore_missing_imports = True ignore_missing_imports = True
[mypy-ucm.*] [mypy-ucm.*]
ignore_missing_imports = True ignore_missing_imports = True

View File

@@ -51,11 +51,6 @@ line-length = 120
# Folder to be modified # Folder to be modified
exclude = [ exclude = [
"tests/**", "tests/**",
# (5)
"vllm_ascend/distributed/kv_transfer/kv_pool/**",
"vllm_ascend/distributed/kv_transfer/utils/**",
"vllm_ascend/kv_offload/**",
"vllm_ascend/lora/**",
# (7) # (7)
"vllm_ascend/quantization/**", "vllm_ascend/quantization/**",
"vllm_ascend/sample/*.py", "vllm_ascend/sample/*.py",

View File

@@ -1,11 +1,10 @@
import threading import threading
from typing import Any, Optional from typing import Any
import torch import torch
import zmq import zmq
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import ( from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
from vllm.forward_context import ForwardContext from vllm.forward_context import ForwardContext
from vllm.logger import logger from vllm.logger import logger
from vllm.utils.network_utils import make_zmq_socket from vllm.utils.network_utils import make_zmq_socket
@@ -17,31 +16,27 @@ from vllm.v1.request import Request
from vllm.v1.serial_utils import MsgpackDecoder from vllm.v1.serial_utils import MsgpackDecoder
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.pool_scheduler import ( from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.pool_scheduler import (
KVPoolScheduler, get_zmq_rpc_path_lookup) KVPoolScheduler,
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.pool_worker import \ get_zmq_rpc_path_lookup,
KVPoolWorker )
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.pool_worker import KVPoolWorker
class AscendStoreConnector(KVConnectorBase_V1): class AscendStoreConnector(KVConnectorBase_V1):
def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole, kv_cache_config: KVCacheConfig | None = None):
def __init__(self, super().__init__(vllm_config=vllm_config, role=role, kv_cache_config=kv_cache_config)
vllm_config: VllmConfig,
role: KVConnectorRole,
kv_cache_config: Optional[KVCacheConfig] = None):
super().__init__(vllm_config=vllm_config,
role=role,
kv_cache_config=kv_cache_config)
self.kv_role = vllm_config.kv_transfer_config.kv_role self.kv_role = vllm_config.kv_transfer_config.kv_role
self.use_layerwise = vllm_config.kv_transfer_config.kv_connector_extra_config.get( self.use_layerwise = vllm_config.kv_transfer_config.kv_connector_extra_config.get("use_layerwise", False)
"use_layerwise", False)
self.consumer_is_to_put = vllm_config.kv_transfer_config.kv_connector_extra_config.get( self.consumer_is_to_put = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"consumer_is_to_put", False) "consumer_is_to_put", False
)
connector_name = vllm_config.kv_transfer_config.kv_connector connector_name = vllm_config.kv_transfer_config.kv_connector
if connector_name == "MooncakeConnectorStoreV1": if connector_name == "MooncakeConnectorStoreV1":
logger.warning( logger.warning(
"It is recommended to use the AscendStoreConnector, as the MoonCakeStoreConnector will be removed in the future." "It is recommended to use the AscendStoreConnector, "
"as the MoonCakeStoreConnector will be removed in the future."
) )
self.kv_caches: dict[str, torch.Tensor] = {} self.kv_caches: dict[str, torch.Tensor] = {}
@@ -49,8 +44,7 @@ class AscendStoreConnector(KVConnectorBase_V1):
self.sended_but_unfinished_reqs: set[str] = set() self.sended_but_unfinished_reqs: set[str] = set()
if role == KVConnectorRole.SCHEDULER: if role == KVConnectorRole.SCHEDULER:
self.connector_scheduler = KVPoolScheduler(vllm_config, self.connector_scheduler = KVPoolScheduler(vllm_config, self.use_layerwise)
self.use_layerwise)
else: else:
self.connector_worker = KVPoolWorker( self.connector_worker = KVPoolWorker(
vllm_config, vllm_config,
@@ -59,27 +53,19 @@ class AscendStoreConnector(KVConnectorBase_V1):
assert self.connector_worker is not None assert self.connector_worker is not None
if vllm_config.parallel_config.rank == 0: if vllm_config.parallel_config.rank == 0:
self.lookup_server = LookupKeyServer(self.connector_worker, self.lookup_server = LookupKeyServer(self.connector_worker, vllm_config, self.use_layerwise)
vllm_config,
self.use_layerwise)
############################################################ ############################################################
# Scheduler Side Methods # Scheduler Side Methods
############################################################ ############################################################
def get_num_new_matched_tokens( def get_num_new_matched_tokens(self, request: "Request", num_computed_tokens: int) -> tuple[int, bool]:
self, request: "Request",
num_computed_tokens: int) -> tuple[int, bool]:
assert self.connector_scheduler is not None assert self.connector_scheduler is not None
return self.connector_scheduler.get_num_new_matched_tokens( return self.connector_scheduler.get_num_new_matched_tokens(request, num_computed_tokens)
request, num_computed_tokens)
def update_state_after_alloc(self, request: "Request", def update_state_after_alloc(self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int):
blocks: "KVCacheBlocks",
num_external_tokens: int):
assert self.connector_scheduler is not None assert self.connector_scheduler is not None
return self.connector_scheduler.update_state_after_alloc( return self.connector_scheduler.update_state_after_alloc(request, blocks, num_external_tokens)
request, blocks, num_external_tokens)
def build_connector_meta( def build_connector_meta(
self, self,
@@ -92,7 +78,7 @@ class AscendStoreConnector(KVConnectorBase_V1):
self, self,
request: "Request", request: "Request",
block_ids: list[int], block_ids: list[int],
) -> tuple[bool, Optional[dict[str, Any]]]: ) -> tuple[bool, dict[str, Any] | None]:
assert self.connector_scheduler is not None assert self.connector_scheduler is not None
return self.connector_scheduler.request_finished(request, block_ids) return self.connector_scheduler.request_finished(request, block_ids)
@@ -103,8 +89,7 @@ class AscendStoreConnector(KVConnectorBase_V1):
assert self.connector_worker is not None assert self.connector_worker is not None
self.connector_worker.register_kv_caches(kv_caches) self.connector_worker.register_kv_caches(kv_caches)
def start_load_kv(self, forward_context: "ForwardContext", def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
**kwargs) -> None:
assert self.connector_worker is not None assert self.connector_worker is not None
self.connector_worker.start_load_kv(self._get_connector_metadata()) self.connector_worker.start_load_kv(self._get_connector_metadata())
@@ -113,8 +98,9 @@ class AscendStoreConnector(KVConnectorBase_V1):
return return
self.connector_worker.wait_for_layer_load() self.connector_worker.wait_for_layer_load()
def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, def save_kv_layer(
attn_metadata: "AttentionMetadata", **kwargs) -> None: self, layer_name: str, kv_layer: torch.Tensor, attn_metadata: "AttentionMetadata", **kwargs
) -> None:
if not self.use_layerwise: if not self.use_layerwise:
return return
@@ -133,17 +119,16 @@ class AscendStoreConnector(KVConnectorBase_V1):
self.connector_worker.wait_for_save(self._get_connector_metadata()) self.connector_worker.wait_for_save(self._get_connector_metadata())
def get_finished(self, def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
"""Get the finished recving and sending requests.""" """Get the finished recving and sending requests."""
assert self.connector_worker is not None assert self.connector_worker is not None
done_sending, done_recving = self.connector_worker.get_finished( done_sending, done_recving = self.connector_worker.get_finished(
finished_req_ids, self._get_connector_metadata()) finished_req_ids, self._get_connector_metadata()
)
return done_sending, done_recving return done_sending, done_recving
class LookupKeyServer: class LookupKeyServer:
def __init__( def __init__(
self, self,
pool_worker: KVPoolWorker, pool_worker: KVPoolWorker,
@@ -171,8 +156,7 @@ class LookupKeyServer:
token_len = int.from_bytes(all_frames[0], byteorder="big") token_len = int.from_bytes(all_frames[0], byteorder="big")
hash_frames = all_frames[1:] hash_frames = all_frames[1:]
hashes_str = self.decoder.decode(hash_frames) hashes_str = self.decoder.decode(hash_frames)
result = self.pool_worker.lookup_scheduler( result = self.pool_worker.lookup_scheduler(token_len, hashes_str, self.use_layerwise)
token_len, hashes_str, self.use_layerwise)
response = result.to_bytes(4, "big") response = result.to_bytes(4, "big")
self.socket.send(response) self.socket.send(response)

View File

@@ -4,13 +4,15 @@ from vllm.config import ParallelConfig
class Backend(ABC): class Backend(ABC):
@abstractmethod
def __init__(self, parallel_config: ParallelConfig): def __init__(self, parallel_config: ParallelConfig):
pass pass
@abstractmethod
def set_device(self): def set_device(self):
pass pass
@abstractmethod
def register_buffer(self, ptrs: list[int], lengths: list[int]): def register_buffer(self, ptrs: list[int], lengths: list[int]):
pass pass
@@ -19,11 +21,9 @@ class Backend(ABC):
pass pass
@abstractmethod @abstractmethod
def put(self, keys: list[str], addrs: list[list[int]], def put(self, keys: list[str], addrs: list[list[int]], sizes: list[list[int]]):
sizes: list[list[int]]):
pass pass
@abstractmethod @abstractmethod
def get(self, keys: list[str], addrs: list[list[int]], def get(self, keys: list[str], addrs: list[list[int]], sizes: list[list[int]]):
sizes: list[list[int]]):
pass pass

View File

@@ -5,8 +5,7 @@ import torch
from vllm.config import ParallelConfig from vllm.config import ParallelConfig
from vllm.logger import logger from vllm.logger import logger
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.backend import \ from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.backend import Backend
Backend
from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type
@@ -18,7 +17,6 @@ class MmcDirect(Enum):
class MemcacheBackend(Backend): class MemcacheBackend(Backend):
def __init__(self, parallel_config: ParallelConfig): def __init__(self, parallel_config: ParallelConfig):
try: try:
from memcache_hybrid import DistributedObjectStore # type: ignore from memcache_hybrid import DistributedObjectStore # type: ignore
@@ -26,21 +24,17 @@ class MemcacheBackend(Backend):
raise ImportError( raise ImportError(
"Please install memcache by following the instructions at " "Please install memcache by following the instructions at "
"https://gitee.com/ascend/memfabric_hybrid " # noqa: E501 "https://gitee.com/ascend/memfabric_hybrid " # noqa: E501
"to run vLLM with MemcacheConnector.") from e "to run vLLM with MemcacheConnector."
) from e
try: try:
soc_version = get_ascend_device_type() soc_version = get_ascend_device_type()
if soc_version in {AscendDeviceType.A2}: if soc_version in {AscendDeviceType.A2}:
import torch import torch
from vllm.distributed import get_world_group from vllm.distributed import get_world_group
tmp_tensor = torch.zeros(1, device="npu") tmp_tensor = torch.zeros(1, device="npu")
output_tensor_list = [ output_tensor_list = [torch.empty_like(tmp_tensor) for _ in range(torch.distributed.get_world_size())]
torch.empty_like(tmp_tensor) torch.distributed.all_gather(output_tensor_list, tmp_tensor, group=get_world_group().device_group)
for _ in range(torch.distributed.get_world_size())
]
torch.distributed.all_gather(
output_tensor_list,
tmp_tensor,
group=get_world_group().device_group)
self.rank = parallel_config.rank self.rank = parallel_config.rank
self.store = DistributedObjectStore() self.store = DistributedObjectStore()
res = self.store.init(self.rank) res = self.store.init(self.rank)
@@ -54,8 +48,7 @@ class MemcacheBackend(Backend):
logger.error("Configuration loading failed: %s", e) logger.error("Configuration loading failed: %s", e)
raise raise
except Exception as exc: except Exception as exc:
logger.error( logger.error("An error occurred while loading the configuration: %s", exc)
"An error occurred while loading the configuration: %s", exc)
raise raise
def set_device(self): def set_device(self):
@@ -73,22 +66,18 @@ class MemcacheBackend(Backend):
def exists(self, keys: list[str]) -> list[int]: def exists(self, keys: list[str]) -> list[int]:
return self.store.batch_is_exist(keys) return self.store.batch_is_exist(keys)
def get(self, key: list[str], addr: list[list[int]], def get(self, key: list[str], addr: list[list[int]], size: list[list[int]]):
size: list[list[int]]):
try: try:
res = self.store.batch_get_into_layers(key, addr, size, res = self.store.batch_get_into_layers(key, addr, size, MmcDirect.COPY_G2L.value)
MmcDirect.COPY_G2L.value)
for value in res: for value in res:
if value != 0: if value != 0:
logger.error(f"Failed to get key {key},res:{res}") logger.error(f"Failed to get key {key},res:{res}")
except Exception as e: except Exception as e:
logger.error(f"Failed to get key {key}. {e}") logger.error(f"Failed to get key {key}. {e}")
def put(self, key: list[str], addr: list[list[int]], def put(self, key: list[str], addr: list[list[int]], size: list[list[int]]):
size: list[list[int]]):
try: try:
res = self.store.batch_put_from_layers(key, addr, size, res = self.store.batch_put_from_layers(key, addr, size, MmcDirect.COPY_L2G.value)
MmcDirect.COPY_L2G.value)
for value in res: for value in res:
if value != 0: if value != 0:
logger.error(f"Failed to get key {key},res:{res}") logger.error(f"Failed to get key {key},res:{res}")

View File

@@ -2,10 +2,9 @@
import json import json
import os import os
import re import re
import torch
from dataclasses import dataclass from dataclasses import dataclass
from typing import Union
import torch
# Third Party # Third Party
from mooncake.store import ReplicateConfig # type: ignore from mooncake.store import ReplicateConfig # type: ignore
@@ -13,17 +12,14 @@ from vllm.config import ParallelConfig
from vllm.logger import logger from vllm.logger import logger
from vllm.utils.network_utils import get_ip from vllm.utils.network_utils import get_ip
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.backend import \ from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.backend import Backend
Backend from vllm_ascend.distributed.kv_transfer.utils.mooncake_transfer_engine import global_te
from vllm_ascend.distributed.kv_transfer.utils.mooncake_transfer_engine import \
global_te
DEFAULT_GLOBAL_SEGMENT_SIZE = 3355443200 # 3.125 GiB DEFAULT_GLOBAL_SEGMENT_SIZE = 3355443200 # 3.125 GiB
DEFAULT_LOCAL_BUFFER_SIZE = 1073741824 # 1.0 GiB DEFAULT_LOCAL_BUFFER_SIZE = 1073741824 # 1.0 GiB
class MooncakeBackend(Backend): class MooncakeBackend(Backend):
def __init__(self, parallel_config: ParallelConfig): def __init__(self, parallel_config: ParallelConfig):
try: try:
from mooncake.store import MooncakeDistributedStore # type: ignore from mooncake.store import MooncakeDistributedStore # type: ignore
@@ -31,23 +27,25 @@ class MooncakeBackend(Backend):
raise ImportError( raise ImportError(
"Please install mooncake by following the instructions at " "Please install mooncake by following the instructions at "
"https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md " # noqa: E501 "https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md " # noqa: E501
"to run vLLM with MooncakeConnector.") from e "to run vLLM with MooncakeConnector."
) from e
self.config = MooncakeStoreConfig.load_from_env() self.config = MooncakeStoreConfig.load_from_env()
self.store = MooncakeDistributedStore() self.store = MooncakeDistributedStore()
self.rank = parallel_config.rank self.rank = parallel_config.rank
if self.config.protocol == "ascend": if self.config.protocol == "ascend":
local_hostname = get_ip() local_hostname = get_ip()
transfer_engine = global_te.get_transfer_engine(local_hostname, transfer_engine = global_te.get_transfer_engine(local_hostname, device_name=None)
device_name=None) self.local_seg = local_hostname + ":" + str(transfer_engine.get_rpc_port())
self.local_seg = local_hostname + ":" + str( ret = self.store.setup(
transfer_engine.get_rpc_port()) self.local_seg,
ret = self.store.setup(self.local_seg, self.config.metadata_server, self.config.metadata_server,
self.config.global_segment_size, self.config.global_segment_size,
self.config.local_buffer_size, self.config.local_buffer_size,
self.config.protocol, self.config.protocol,
self.config.device_name, self.config.device_name,
self.config.master_server_address, self.config.master_server_address,
transfer_engine.get_engine()) transfer_engine.get_engine(),
)
if ret != 0: if ret != 0:
msg = "Initialize mooncake failed." msg = "Initialize mooncake failed."
logger.error(msg) logger.error(msg)
@@ -63,25 +61,21 @@ class MooncakeBackend(Backend):
def exists(self, keys: list[str]) -> list[int]: def exists(self, keys: list[str]) -> list[int]:
return self.store.batch_is_exist(keys) return self.store.batch_is_exist(keys)
def put(self, keys: list[str], addrs: list[list[int]], def put(self, keys: list[str], addrs: list[list[int]], sizes: list[list[int]]):
sizes: list[list[int]]):
try: try:
config = ReplicateConfig() config = ReplicateConfig()
config.preferred_segment = self.local_seg config.preferred_segment = self.local_seg
config.prefer_alloc_in_same_node = True config.prefer_alloc_in_same_node = True
res = self.store.batch_put_from_multi_buffers( res = self.store.batch_put_from_multi_buffers(keys, addrs, sizes, config)
keys, addrs, sizes, config)
for value in res: for value in res:
if value < 0: if value < 0:
logger.error(f"Failed to put key {keys},res:{res}") logger.error(f"Failed to put key {keys},res:{res}")
except Exception as e: except Exception as e:
logger.error(f"Failed to put key {keys},error:{e}") logger.error(f"Failed to put key {keys},error:{e}")
def get(self, keys: list[str], addrs: list[list[int]], def get(self, keys: list[str], addrs: list[list[int]], sizes: list[list[int]]):
sizes: list[list[int]]):
try: try:
res = self.store.batch_get_into_multi_buffers( res = self.store.batch_get_into_multi_buffers(keys, addrs, sizes, True)
keys, addrs, sizes, True)
for value in res: for value in res:
if value < 0: if value < 0:
logger.error(f"Failed to get key {keys}, res:{res}") logger.error(f"Failed to get key {keys}, res:{res}")
@@ -92,7 +86,7 @@ class MooncakeBackend(Backend):
@dataclass @dataclass
class MooncakeStoreConfig: class MooncakeStoreConfig:
metadata_server: str metadata_server: str
global_segment_size: Union[int, str] global_segment_size: int | str
local_buffer_size: int local_buffer_size: int
protocol: str protocol: str
device_name: str device_name: str
@@ -105,33 +99,32 @@ class MooncakeStoreConfig:
return MooncakeStoreConfig( return MooncakeStoreConfig(
metadata_server=config.get("metadata_server"), metadata_server=config.get("metadata_server"),
global_segment_size=_parse_global_segment_size( global_segment_size=_parse_global_segment_size(
config.get("global_segment_size", config.get("global_segment_size", DEFAULT_GLOBAL_SEGMENT_SIZE)
DEFAULT_GLOBAL_SEGMENT_SIZE)), ),
local_buffer_size=_parse_global_segment_size( local_buffer_size=_parse_global_segment_size(config.get("local_buffer_size", DEFAULT_LOCAL_BUFFER_SIZE)),
config.get("local_buffer_size", DEFAULT_LOCAL_BUFFER_SIZE)),
protocol=config.get("protocol", "ascend"), protocol=config.get("protocol", "ascend"),
device_name=config.get("device_name", ""), device_name=config.get("device_name", ""),
master_server_address=config.get("master_server_address")) master_server_address=config.get("master_server_address"),
)
@staticmethod @staticmethod
def load_from_env() -> "MooncakeStoreConfig": def load_from_env() -> "MooncakeStoreConfig":
config_path = os.getenv("MOONCAKE_CONFIG_PATH") config_path = os.getenv("MOONCAKE_CONFIG_PATH")
if not config_path: if not config_path:
raise ValueError( raise ValueError("The environment variable 'MOONCAKE_CONFIG_PATH' is not set.")
"The environment variable 'MOONCAKE_CONFIG_PATH' is not set.")
return MooncakeStoreConfig.from_file(config_path) return MooncakeStoreConfig.from_file(config_path)
def _parse_global_segment_size(value) -> int: def _parse_global_segment_size(value) -> int:
""" """
Parse storage size strings with support for units: GB, MB, KB, B Parse storage size strings with support for units: GB, MB, KB, B
Args: Args:
value: Input value (int, str, or other convertible types) value: Input value (int, str, or other convertible types)
Returns: Returns:
int: Size in bytes int: Size in bytes
Raises: Raises:
ValueError: For invalid format, missing number, or negative values ValueError: For invalid format, missing number, or negative values
TypeError: For unsupported input types TypeError: For unsupported input types
@@ -143,54 +136,50 @@ def _parse_global_segment_size(value) -> int:
try: try:
return int(value) return int(value)
except (TypeError, ValueError) as e: except (TypeError, ValueError) as e:
raise TypeError( raise TypeError(f"Unsupported type for global_segment_size: {type(value)}") from e
f"Unsupported type for global_segment_size: {type(value)}"
) from e
cleaned_input = value.strip().lower() cleaned_input = value.strip().lower()
if not cleaned_input: if not cleaned_input:
raise ValueError("global segment size cannot be empty.") raise ValueError("global segment size cannot be empty.")
UNIT_MULTIPLIERS = { UNIT_MULTIPLIERS = {
'gb': 1024**3, # 1 GB = 1024^3 bytes "gb": 1024**3, # 1 GB = 1024^3 bytes
'mb': 1024**2, # 1 MB = 1024^2 bytes "mb": 1024**2, # 1 MB = 1024^2 bytes
'kb': 1024, # 1 KB = 1024 bytes "kb": 1024, # 1 KB = 1024 bytes
'b': 1 # 1 B = 1 byte "b": 1, # 1 B = 1 byte
} }
pattern = r'^\s*([\d.]+)\s*(gb|mb|kb|b)?\s*$' pattern = r"^\s*([\d.]+)\s*(gb|mb|kb|b)?\s*$"
match = re.match(pattern, cleaned_input) match = re.match(pattern, cleaned_input)
if not match: if not match:
raise ValueError(f"Invalid format: '{value}'") raise ValueError(f"Invalid format: '{value}'")
number_str = match.group(1) number_str = match.group(1)
unit = match.group(2) or 'b' unit = match.group(2) or "b"
multiplier = UNIT_MULTIPLIERS[unit] multiplier = UNIT_MULTIPLIERS[unit]
return _convert_to_bytes(number_str, multiplier, value) return _convert_to_bytes(number_str, multiplier, value)
def _convert_to_bytes(number_str: str, multiplier: int, def _convert_to_bytes(number_str: str, multiplier: int, original_input: str) -> int:
original_input: str) -> int:
""" """
Convert numeric string to byte count Convert numeric string to byte count
Args: Args:
number_str: Numeric portion of input number_str: Numeric portion of input
multiplier: Unit conversion factor multiplier: Unit conversion factor
original_input: Original input string (for error messages) original_input: Original input string (for error messages)
Returns: Returns:
int: Byte count int: Byte count
Raises: Raises:
ValueError: For invalid numbers or negative results ValueError: For invalid numbers or negative results
""" """
try: try:
numeric_value = float(number_str) numeric_value = float(number_str)
except ValueError: except ValueError:
raise ValueError( raise ValueError(f"Invalid numeric value '{number_str}' in: '{original_input}'")
f"Invalid numeric value '{number_str}' in: '{original_input}'")
# Calculate byte count # Calculate byte count
try: try:
byte_count = int(numeric_value * multiplier) byte_count = int(numeric_value * multiplier)

View File

@@ -1,16 +1,16 @@
from collections.abc import Iterable
from dataclasses import dataclass from dataclasses import dataclass
from typing import Iterable, List, Optional, Tuple, Union from typing import Optional
import torch import torch
from vllm.distributed.kv_transfer.kv_connector.v1.base import \ from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata
KVConnectorMetadata
from vllm.logger import logger from vllm.logger import logger
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
from vllm.v1.core.kv_cache_utils import BlockHash from vllm.v1.core.kv_cache_utils import BlockHash
from vllm.v1.core.sched.output import NewRequestData from vllm.v1.core.sched.output import NewRequestData
#Parameters related to the key # Parameters related to the key
@dataclass @dataclass
class KeyMetadata: class KeyMetadata:
"""name of the LLM model""" """name of the LLM model"""
@@ -32,23 +32,26 @@ class PoolKey:
chunk_hash: str chunk_hash: str
def __hash__(self): def __hash__(self):
return hash(( return hash(
self.key_metadata.model_name, (
self.key_metadata.head_or_tp_rank, self.key_metadata.model_name,
self.key_metadata.pcp_rank, self.key_metadata.head_or_tp_rank,
self.key_metadata.dcp_rank, self.key_metadata.pcp_rank,
self.key_metadata.pp_rank, self.key_metadata.dcp_rank,
self.chunk_hash, self.key_metadata.pp_rank,
)) self.chunk_hash,
)
)
def to_string(self): def to_string(self):
return ( return (
f"{self.key_metadata.model_name}" f"{self.key_metadata.model_name}"
f"@pcp{self.key_metadata.pcp_rank}@dcp{self.key_metadata.dcp_rank}" f"@pcp{self.key_metadata.pcp_rank}@dcp{self.key_metadata.dcp_rank}"
f"@head_or_tp_rank:{self.key_metadata.head_or_tp_rank}" f"@head_or_tp_rank:{self.key_metadata.head_or_tp_rank}"
f"@pp_rank:{self.key_metadata.pp_rank}@{self.chunk_hash}") f"@pp_rank:{self.key_metadata.pp_rank}@{self.chunk_hash}"
)
def split_layers(self, num_layers: int) -> List["LayerPoolKey"]: def split_layers(self, num_layers: int) -> list["LayerPoolKey"]:
"""Split the key into multiple keys for each layer""" """Split the key into multiple keys for each layer"""
keys = [] keys = []
for layer_id in range(num_layers): for layer_id in range(num_layers):
@@ -57,7 +60,8 @@ class PoolKey:
self.key_metadata, self.key_metadata,
self.chunk_hash, self.chunk_hash,
layer_id, layer_id,
)) )
)
return keys return keys
@@ -68,14 +72,16 @@ class LayerPoolKey(PoolKey):
layer_id: int layer_id: int
def __hash__(self): def __hash__(self):
return hash(( return hash(
self.key_metadata.model_name, (
self.key_metadata.head_or_tp_rank, self.key_metadata.model_name,
self.key_metadata.pcp_rank, self.key_metadata.head_or_tp_rank,
self.key_metadata.dcp_rank, self.key_metadata.pcp_rank,
self.chunk_hash, self.key_metadata.dcp_rank,
self.layer_id, self.chunk_hash,
)) self.layer_id,
)
)
def to_string(self): def to_string(self):
return ( return (
@@ -85,10 +91,8 @@ class LayerPoolKey(PoolKey):
) )
class ChunkedTokenDatabase(): class ChunkedTokenDatabase:
def __init__(self, metadata: KeyMetadata, block_size: int, use_mla: bool, partitions: list[int] | None):
def __init__(self, metadata: KeyMetadata, block_size: int, use_mla: bool,
partitions: Optional[List[int]]):
self.metadata = metadata self.metadata = metadata
self.block_size = block_size self.block_size = block_size
self.use_mla = use_mla self.use_mla = use_mla
@@ -96,9 +100,7 @@ class ChunkedTokenDatabase():
self.block_len: list[int] = [] self.block_len: list[int] = []
self.partitions = partitions self.partitions = partitions
def _make_key_by_hash(self, def _make_key_by_hash(self, chunk_hash: str, layer_id: int | None = None):
chunk_hash: str,
layer_id: Optional[int] = None):
assert self.metadata is not None assert self.metadata is not None
return PoolKey( return PoolKey(
self.metadata, self.metadata,
@@ -116,8 +118,7 @@ class ChunkedTokenDatabase():
size_list = [] size_list = []
block_id = block_ids[start // self.block_size] block_id = block_ids[start // self.block_size]
for index, base_addr in enumerate(self.kv_caches_base_addr): for index, base_addr in enumerate(self.kv_caches_base_addr):
block_len = (self.block_len[index % 2] block_len = self.block_len[index % 2] if self.use_mla else self.block_len[0]
if self.use_mla else self.block_len[0])
addr = base_addr + block_id * block_len addr = base_addr + block_id * block_len
length = int(block_len / self.block_size * (end - start)) length = int(block_len / self.block_size * (end - start))
@@ -125,22 +126,17 @@ class ChunkedTokenDatabase():
size_list.append(length) size_list.append(length)
return addr_list, size_list, block_id return addr_list, size_list, block_id
def prepare_value_layer(self, start: int, end: int, block_ids: list[int], def prepare_value_layer(self, start: int, end: int, block_ids: list[int], layer_id: int):
layer_id: int):
block_id = block_ids[start // self.block_size] block_id = block_ids[start // self.block_size]
if self.use_mla: if self.use_mla:
addr_k = self.kv_caches_base_addr[layer_id * addr_k = self.kv_caches_base_addr[layer_id * 2] + block_id * self.block_len[0]
2] + block_id * self.block_len[0] addr_v = self.kv_caches_base_addr[layer_id * 2 + 1] + block_id * self.block_len[1]
addr_v = self.kv_caches_base_addr[layer_id * 2 +
1] + block_id * self.block_len[1]
length_k = int(self.block_len[0] / self.block_size * (end - start)) length_k = int(self.block_len[0] / self.block_size * (end - start))
length_v = int(self.block_len[1] / self.block_size * (end - start)) length_v = int(self.block_len[1] / self.block_size * (end - start))
size_list = [length_k, length_v] size_list = [length_k, length_v]
else: else:
addr_k = self.kv_caches_base_addr[layer_id * addr_k = self.kv_caches_base_addr[layer_id * 2] + block_id * self.block_len[0]
2] + block_id * self.block_len[0] addr_v = self.kv_caches_base_addr[layer_id * 2 + 1] + block_id * self.block_len[0]
addr_v = self.kv_caches_base_addr[layer_id * 2 +
1] + block_id * self.block_len[0]
length = int(self.block_len[0] / self.block_size * (end - start)) length = int(self.block_len[0] / self.block_size * (end - start))
size_list = [length, length] size_list = [length, length]
addr_list = [addr_k, addr_v] addr_list = [addr_k, addr_v]
@@ -149,9 +145,9 @@ class ChunkedTokenDatabase():
def process_tokens( def process_tokens(
self, self,
token_len: int, token_len: int,
block_hashes: Union[list[BlockHash], list[str]], block_hashes: list[BlockHash] | list[str],
mask_num: int = 0, mask_num: int = 0,
) -> Iterable[Tuple[int, int, PoolKey]]: ) -> Iterable[tuple[int, int, PoolKey]]:
"""Process the tokens and return the corresponding cache engine keys. """Process the tokens and return the corresponding cache engine keys.
:param Union[torch.Tensor, List[int]] tokens: The tokens to process. :param Union[torch.Tensor, List[int]] tokens: The tokens to process.
@@ -202,10 +198,10 @@ class ChunkedTokenDatabase():
start = 0 start = 0
for j, part in enumerate(self.partitions): for j, part in enumerate(self.partitions):
# part * 2 because addr and size contain both k and v # part * 2 because addr and size contain both k and v
end = len(addr_list) if j == len( end = len(addr_list) if j == len(self.partitions) - 1 else start + part * 2
self.partitions) - 1 else start + part * 2
new_str = key[i].replace( # type: ignore[attr-defined] new_str = key[i].replace( # type: ignore[attr-defined]
"@pp_rank:0", f"@pp_rank:{j}", 1) "@pp_rank:0", f"@pp_rank:{j}", 1
)
new_key.append(new_str) new_key.append(new_str)
new_addr.append(addr_list[start:end]) new_addr.append(addr_list[start:end])
new_size.append(size_list[start:end]) new_size.append(size_list[start:end])
@@ -213,7 +209,7 @@ class ChunkedTokenDatabase():
return new_key, new_addr, new_size return new_key, new_addr, new_size
#Parameters related to the connector metadata # Parameters related to the connector metadata
@dataclass @dataclass
class LoadSpec: class LoadSpec:
# Number of tokens cached in vLLM # Number of tokens cached in vLLM
@@ -273,7 +269,7 @@ class RequestTracker:
def update( def update(
self, self,
new_block_ids: Union[tuple[list[int], ...], list[int]], new_block_ids: tuple[list[int], ...] | list[int],
) -> None: ) -> None:
"""Update the request tracker when a running request is """Update the request tracker when a running request is
scheduled again scheduled again
@@ -286,8 +282,7 @@ class RequestTracker:
elif isinstance(new_block_ids, list): elif isinstance(new_block_ids, list):
pass pass
else: else:
raise ValueError( raise ValueError(f"Unsupported new_block_ids type {type(new_block_ids)}")
f"Unsupported new_block_ids type {type(new_block_ids)}")
self.allocated_block_ids.extend(new_block_ids) self.allocated_block_ids.extend(new_block_ids)
@@ -302,22 +297,22 @@ class ReqMeta:
block_hashes: list[BlockHash] block_hashes: list[BlockHash]
can_save: Optional[bool] = None can_save: bool | None = None
# load_spec # load_spec
load_spec: Optional[LoadSpec] = None load_spec: LoadSpec | None = None
is_last_chunk: Optional[bool] = None is_last_chunk: bool | None = None
current_event: Optional[torch.npu.Event] = None current_event: torch.npu.Event | None = None
@staticmethod @staticmethod
def from_request_tracker( def from_request_tracker(
tracker: RequestTracker, tracker: RequestTracker,
block_size: int, block_size: int,
load_spec: Optional[LoadSpec] = None, load_spec: LoadSpec | None = None,
skip_save: Optional[bool] = False, skip_save: bool | None = False,
block_hashes: list[BlockHash] = [], block_hashes: list[BlockHash] | None = None,
is_last_chunk: Optional[bool] = None, is_last_chunk: bool | None = None,
discard_partial_chunks: bool = True, discard_partial_chunks: bool = True,
) -> Optional["ReqMeta"]: ) -> Optional["ReqMeta"]:
"""Create the request metadata from a request tracker. """Create the request metadata from a request tracker.
@@ -333,17 +328,17 @@ class ReqMeta:
the request metadata if we need to perform load/save the request metadata if we need to perform load/save
operations, None otherwise. operations, None otherwise.
""" """
if block_hashes is None:
block_hashes = []
input_token_len = tracker.token_len input_token_len = tracker.token_len
# For save operation: do not save if the following condition is met # For save operation: do not save if the following condition is met
# 1. has already been saved before (num_saved_tokens > 0) # 1. has already been saved before (num_saved_tokens > 0)
# 2. number of unsaved tokens is not reached the chunk boundary # 2. number of unsaved tokens is not reached the chunk boundary
chunk_boundary = (cdiv(tracker.num_saved_tokens + 1, block_size) * chunk_boundary = cdiv(tracker.num_saved_tokens + 1, block_size) * block_size if discard_partial_chunks else 0
block_size if discard_partial_chunks else 0)
# Calculate number of tokens to save based on discard_partial_chunks # Calculate number of tokens to save based on discard_partial_chunks
# setting # setting
num_tokens_to_save = ((input_token_len // block_size * block_size) num_tokens_to_save = (input_token_len // block_size * block_size) if discard_partial_chunks else input_token_len
if discard_partial_chunks else input_token_len)
skip_save = skip_save or num_tokens_to_save < chunk_boundary skip_save = skip_save or num_tokens_to_save < chunk_boundary
if skip_save and load_spec is None: if skip_save and load_spec is None:
@@ -363,9 +358,7 @@ class ReqMeta:
else: else:
# Do not load if not in `can_load` state # Do not load if not in `can_load` state
load_spec = None load_spec = None
logger.debug( logger.debug(f"request:{tracker.req_id}, meta save spec:{not skip_save}, meta load spec:{load_spec}")
f"request:{tracker.req_id}, meta save spec:{not skip_save}, meta load spec:{load_spec}"
)
return ReqMeta( return ReqMeta(
req_id=tracker.req_id, req_id=tracker.req_id,
token_len_chunk=num_tokens_to_save, token_len_chunk=num_tokens_to_save,
@@ -378,7 +371,6 @@ class ReqMeta:
class AscendConnectorMetadata(KVConnectorMetadata): class AscendConnectorMetadata(KVConnectorMetadata):
def __init__(self, unfinished_request_ids, preempted_req_ids): def __init__(self, unfinished_request_ids, preempted_req_ids):
self.requests = [] self.requests = []
self.unfinished_request_ids = unfinished_request_ids self.unfinished_request_ids = unfinished_request_ids
@@ -396,10 +388,10 @@ class AscendConnectorMetadata(KVConnectorMetadata):
@dataclass @dataclass
class LasyerMultiBlockReqMeta: class LasyerMultiBlockReqMeta:
req_id: str req_id: str
keys: List[LayerPoolKey] keys: list[LayerPoolKey]
starts: List[int] starts: list[int]
ends: list[int] ends: list[int]
block_ids: list[int] block_ids: list[int]
layer_id: int layer_id: int
is_last_chunk: Optional[bool] = True is_last_chunk: bool | None = True
current_event: Optional[torch.npu.Event] = None current_event: torch.npu.Event | None = None

View File

@@ -7,8 +7,7 @@ from typing import Any
import torch import torch
from vllm.logger import logger from vllm.logger import logger
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.backend import \ from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.backend import Backend
Backend
# isort: off # isort: off
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.config_data import ( from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.config_data import (
@@ -20,10 +19,16 @@ from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.config_data import
class KVTransferThread(threading.Thread): class KVTransferThread(threading.Thread):
def __init__(
def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase, self,
block_size: int, tp_rank: int, dcp_size: int, m_store: Backend,
ready_event: threading.Event, name: str): token_database: ChunkedTokenDatabase,
block_size: int,
tp_rank: int,
dcp_size: int,
ready_event: threading.Event,
name: str,
):
super().__init__(daemon=True, name=name) super().__init__(daemon=True, name=name)
self.m_store = m_store self.m_store = m_store
self.ready_event = ready_event self.ready_event = ready_event
@@ -39,7 +44,7 @@ class KVTransferThread(threading.Thread):
def add_request( def add_request(
self, self,
request: ReqMeta, request: ReqMeta | LasyerMultiBlockReqMeta,
) -> torch.Tensor: ) -> torch.Tensor:
self.request_queue.put(request) self.request_queue.put(request)
@@ -98,17 +103,20 @@ class KVTransferThread(threading.Thread):
class KVCacheStoreSendingThread(KVTransferThread): class KVCacheStoreSendingThread(KVTransferThread):
def __init__(
def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase, self,
block_size: int, tp_rank: int, dcp_size: int, put_step: int, m_store: Backend,
kv_role: str, ready_event: threading.Event): token_database: ChunkedTokenDatabase,
super().__init__(m_store, block_size: int,
token_database, tp_rank: int,
block_size, dcp_size: int,
tp_rank, put_step: int,
dcp_size, kv_role: str,
ready_event, ready_event: threading.Event,
name="KVCacheSendingThread") ):
super().__init__(
m_store, token_database, block_size, tp_rank, dcp_size, ready_event, name="KVCacheSendingThread"
)
self.put_step = put_step self.put_step = put_step
self.kv_role = kv_role self.kv_role = kv_role
self.stored_requests = defaultdict[str, int](int) self.stored_requests = defaultdict[str, int](int)
@@ -139,16 +147,15 @@ class KVCacheStoreSendingThread(KVTransferThread):
self.request_queue.task_done() self.request_queue.task_done()
return return
for start, end, key in self.token_database.process_tokens( for start, end, key in self.token_database.process_tokens(token_len, req_meta.block_hashes):
token_len, req_meta.block_hashes):
starts.append(start) starts.append(start)
ends.append(end) ends.append(end)
keys.append(key.to_string()) keys.append(key.to_string())
if not self.dcp_size > 1: if not self.dcp_size > 1:
starts = starts[self.tp_rank % self.put_step::self.put_step] starts = starts[self.tp_rank % self.put_step :: self.put_step]
ends = ends[self.tp_rank % self.put_step::self.put_step] ends = ends[self.tp_rank % self.put_step :: self.put_step]
keys = keys[self.tp_rank % self.put_step::self.put_step] keys = keys[self.tp_rank % self.put_step :: self.put_step]
if not keys: if not keys:
self.dec_stored_request(req_id) self.dec_stored_request(req_id)
@@ -165,8 +172,7 @@ class KVCacheStoreSendingThread(KVTransferThread):
keys = keys[skip_block_num:] keys = keys[skip_block_num:]
logger.info( logger.info(
"Storing KV cache for %d out of %d blocks " "Storing KV cache for %d out of %d blocks (skip_block_num=%d) for request %s",
"(skip_block_num=%d) for request %s",
len(keys), len(keys),
token_len // self.block_size, token_len // self.block_size,
skip_block_num, skip_block_num,
@@ -183,14 +189,12 @@ class KVCacheStoreSendingThread(KVTransferThread):
addrs = [] addrs = []
sizes = [] sizes = []
for index, start in enumerate(starts): for index, start in enumerate(starts):
addr, size, _ = self.token_database.prepare_value( addr, size, _ = self.token_database.prepare_value(start, ends[index], block_ids)
start, ends[index], block_ids)
addrs.append(addr) addrs.append(addr)
sizes.append(size) sizes.append(size)
if self.kv_role == "kv_consumer": if self.kv_role == "kv_consumer":
keys, addrs, sizes = self.token_database.decode_adaptor_prefill_pp( keys, addrs, sizes = self.token_database.decode_adaptor_prefill_pp(keys, addrs, sizes)
keys, addrs, sizes)
if current_event is not None: if current_event is not None:
current_event.synchronize() current_event.synchronize()
@@ -201,69 +205,69 @@ class KVCacheStoreSendingThread(KVTransferThread):
class KVCacheStoreRecvingThread(KVTransferThread): class KVCacheStoreRecvingThread(KVTransferThread):
def __init__(
def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase, self,
block_size: int, tp_rank: int, dcp_size: int, m_store: Backend,
ready_event: threading.Event): token_database: ChunkedTokenDatabase,
super().__init__(m_store, block_size: int,
token_database, tp_rank: int,
block_size, dcp_size: int,
tp_rank, ready_event: threading.Event,
dcp_size, ):
ready_event, super().__init__(
name="KVCacheStoreRecvingThread") m_store, token_database, block_size, tp_rank, dcp_size, ready_event, name="KVCacheStoreRecvingThread"
)
def _handle_request(self, req_meta: ReqMeta): def _handle_request(self, req_meta: ReqMeta):
token_len = req_meta.load_spec.token_len # type: ignore[union-attr] token_len = req_meta.load_spec.token_len # type: ignore[union-attr]
req_id = req_meta.req_id req_id = req_meta.req_id
mask_num = ( mask_num = (
req_meta.load_spec.vllm_cached_tokens # type: ignore[union-attr] req_meta.load_spec.vllm_cached_tokens # type: ignore[union-attr]
// self.block_size * self.block_size) // self.block_size
* self.block_size
)
addr_list = [] addr_list = []
size_list = [] size_list = []
key_list = [] key_list = []
for start, end, key in self.token_database.process_tokens( for start, end, key in self.token_database.process_tokens(token_len, req_meta.block_hashes, mask_num):
token_len, req_meta.block_hashes, mask_num): addr, size, _ = self.token_database.prepare_value(start, end, req_meta.block_ids)
addr, size, _ = self.token_database.prepare_value(
start, end, req_meta.block_ids)
key_list.append(key.to_string()) key_list.append(key.to_string())
addr_list.append(addr) addr_list.append(addr)
size_list.append(size) size_list.append(size)
key_list_c = key_list[self.tp_rank % key_list_c = key_list[self.tp_rank % len(key_list) :] + key_list[: self.tp_rank % len(key_list)]
len(key_list):] + key_list[:self.tp_rank % addr_list_c = addr_list[self.tp_rank % len(addr_list) :] + addr_list[: self.tp_rank % len(addr_list)]
len(key_list)] size_list_c = size_list[self.tp_rank % len(size_list) :] + size_list[: self.tp_rank % len(size_list)]
addr_list_c = addr_list[self.tp_rank %
len(addr_list):] + addr_list[:self.tp_rank %
len(addr_list)]
size_list_c = size_list[self.tp_rank %
len(size_list):] + size_list[:self.tp_rank %
len(size_list)]
self.m_store.get(key_list_c, addr_list_c, size_list_c) self.m_store.get(key_list_c, addr_list_c, size_list_c)
self.set_finished_request(req_id) self.set_finished_request(req_id)
self.request_queue.task_done() self.request_queue.task_done()
class KVCacheStoreLayerSendingThread(KVTransferThread): class KVCacheStoreLayerSendingThread(KVTransferThread):
def __init__(
def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase, self,
block_size: int, tp_rank: int, dcp_size: int, put_step: int, m_store: Backend,
ready_event: threading.Event, num_layers: int): token_database: ChunkedTokenDatabase,
super().__init__(m_store, block_size: int,
token_database, tp_rank: int,
block_size, dcp_size: int,
tp_rank, put_step: int,
dcp_size, ready_event: threading.Event,
ready_event, num_layers: int,
name="KVCacheStoreLayerSendingThread") ):
super().__init__(
m_store, token_database, block_size, tp_rank, dcp_size, ready_event, name="KVCacheStoreLayerSendingThread"
)
self.final_layer_id = num_layers - 1 self.final_layer_id = num_layers - 1
self.put_step = put_step self.put_step = put_step
def add_request( # type: ignore[override] def add_request( # type: ignore[override]
self, req_meta: ReqMeta) -> torch.Tensor: self, req_meta: ReqMeta
) -> torch.Tensor:
self.request_queue.put(req_meta) self.request_queue.put(req_meta)
def _handle_request( # type: ignore[override] def _handle_request( # type: ignore[override]
self, req_meta: LasyerMultiBlockReqMeta): self, req_meta: LasyerMultiBlockReqMeta
):
starts = req_meta.starts starts = req_meta.starts
ends = req_meta.ends ends = req_meta.ends
keys = req_meta.keys keys = req_meta.keys
@@ -272,9 +276,9 @@ class KVCacheStoreLayerSendingThread(KVTransferThread):
total_block = len(keys) total_block = len(keys)
is_last_chunk = req_meta.is_last_chunk is_last_chunk = req_meta.is_last_chunk
if not self.dcp_size > 1: if not self.dcp_size > 1:
starts = starts[self.tp_rank % self.put_step::self.put_step] starts = starts[self.tp_rank % self.put_step :: self.put_step]
ends = ends[self.tp_rank % self.put_step::self.put_step] ends = ends[self.tp_rank % self.put_step :: self.put_step]
keys = keys[self.tp_rank % self.put_step::self.put_step] keys = keys[self.tp_rank % self.put_step :: self.put_step]
if not keys: if not keys:
if is_last_chunk: if is_last_chunk:
@@ -300,7 +304,8 @@ class KVCacheStoreLayerSendingThread(KVTransferThread):
size_list = [] size_list = []
for index, key in enumerate(key_list): for index, key in enumerate(key_list):
addr, size = self.token_database.prepare_value_layer( addr, size = self.token_database.prepare_value_layer(
starts[index], ends[index], req_meta.block_ids, layer_id) starts[index], ends[index], req_meta.block_ids, layer_id
)
addr_list.append(addr) addr_list.append(addr)
size_list.append(size) size_list.append(size)
@@ -313,8 +318,7 @@ class KVCacheStoreLayerSendingThread(KVTransferThread):
self.request_queue.task_done() self.request_queue.task_done()
logger.info( logger.info(
"Storing KV cache for %d out of %d blocks " "Storing KV cache for %d out of %d blocks (skip_block_num=%d) for request %s",
"(skip_block_num=%d) for request %s",
len(keys), len(keys),
total_block, total_block,
skip_block_num, skip_block_num,
@@ -323,44 +327,42 @@ class KVCacheStoreLayerSendingThread(KVTransferThread):
class KVCacheStoreLayerRecvingThread(KVTransferThread): class KVCacheStoreLayerRecvingThread(KVTransferThread):
def __init__(
def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase, self,
block_size: int, tp_rank: int, dcp_size: int, m_store: Backend,
ready_event: threading.Event, get_event: threading.Event): token_database: ChunkedTokenDatabase,
super().__init__(m_store, block_size: int,
token_database, tp_rank: int,
block_size, dcp_size: int,
tp_rank, ready_event: threading.Event,
dcp_size, get_event: threading.Event,
ready_event, ):
name="KVCacheStoreLayerRecvingThread") super().__init__(
m_store, token_database, block_size, tp_rank, dcp_size, ready_event, name="KVCacheStoreLayerRecvingThread"
)
self.get_event = get_event self.get_event = get_event
def add_request( # type: ignore[override] def add_request( # type: ignore[override]
self, req_meta: LasyerMultiBlockReqMeta) -> torch.Tensor: self, req_meta: LasyerMultiBlockReqMeta
) -> torch.Tensor:
self.request_queue.put(req_meta) self.request_queue.put(req_meta)
def _handle_request( # type: ignore[override] def _handle_request( # type: ignore[override]
self, req_meta: LasyerMultiBlockReqMeta): self, req_meta: LasyerMultiBlockReqMeta
):
addr_list = [] addr_list = []
size_list = [] size_list = []
key_list = [] key_list = []
for index, key in enumerate(req_meta.keys): for index, key in enumerate(req_meta.keys):
addr, size = self.token_database.prepare_value_layer( addr, size = self.token_database.prepare_value_layer(
req_meta.starts[index], req_meta.ends[index], req_meta.starts[index], req_meta.ends[index], req_meta.block_ids, req_meta.layer_id
req_meta.block_ids, req_meta.layer_id) )
key_list.append(key.to_string()) key_list.append(key.to_string())
addr_list.append(addr) addr_list.append(addr)
size_list.append(size) size_list.append(size)
key_list_c = key_list[self.tp_rank % key_list_c = key_list[self.tp_rank % len(key_list) :] + key_list[: self.tp_rank % len(key_list)]
len(key_list):] + key_list[:self.tp_rank % addr_list_c = addr_list[self.tp_rank % len(addr_list) :] + addr_list[: self.tp_rank % len(addr_list)]
len(key_list)] size_list_c = size_list[self.tp_rank % len(size_list) :] + size_list[: self.tp_rank % len(size_list)]
addr_list_c = addr_list[self.tp_rank %
len(addr_list):] + addr_list[:self.tp_rank %
len(addr_list)]
size_list_c = size_list[self.tp_rank %
len(size_list):] + size_list[:self.tp_rank %
len(size_list)]
self.m_store.get(key_list_c, addr_list_c, size_list_c) self.m_store.get(key_list_c, addr_list_c, size_list_c)
self.request_queue.task_done() self.request_queue.task_done()

View File

@@ -1,10 +1,9 @@
from typing import Any, Optional from typing import Any
import vllm.envs as envs import vllm.envs as envs
import zmq import zmq
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import \ from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata
KVConnectorMetadata
from vllm.logger import logger from vllm.logger import logger
from vllm.utils.network_utils import make_zmq_socket from vllm.utils.network_utils import make_zmq_socket
from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.core.kv_cache_manager import KVCacheBlocks
@@ -14,27 +13,29 @@ from vllm.v1.request import Request
from vllm.v1.serial_utils import MsgpackEncoder from vllm.v1.serial_utils import MsgpackEncoder
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.config_data import ( from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.config_data import (
AscendConnectorMetadata, LoadSpec, ReqMeta, RequestTracker) AscendConnectorMetadata,
LoadSpec,
ReqMeta,
RequestTracker,
)
class KVPoolScheduler: class KVPoolScheduler:
def __init__(self, vllm_config: "VllmConfig", use_layerwise): def __init__(self, vllm_config: "VllmConfig", use_layerwise):
self.use_layerwise = use_layerwise self.use_layerwise = use_layerwise
self.kv_role = vllm_config.kv_transfer_config.kv_role self.kv_role = vllm_config.kv_transfer_config.kv_role
self.consumer_is_to_load = vllm_config.kv_transfer_config.kv_connector_extra_config.get( self.consumer_is_to_load = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"consumer_is_to_load", False) "consumer_is_to_load", False
)
self.consumer_is_to_put = vllm_config.kv_transfer_config.kv_connector_extra_config.get( self.consumer_is_to_put = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"consumer_is_to_put", False) "consumer_is_to_put", False
self.load_async = vllm_config.kv_transfer_config.kv_connector_extra_config.get( )
"load_async", False) self.load_async = vllm_config.kv_transfer_config.kv_connector_extra_config.get("load_async", False)
self.client = LookupKeyClient(vllm_config) self.client = LookupKeyClient(vllm_config)
# request_id -> (vllm cached tokes, kvpool cached tokens) # request_id -> (vllm cached tokes, kvpool cached tokens)
self.load_specs: dict[str, LoadSpec] = {} self.load_specs: dict[str, LoadSpec] = {}
self.pcp_size = getattr(vllm_config.parallel_config, self.pcp_size = getattr(vllm_config.parallel_config, "prefill_context_parallel_size", 1)
"prefill_context_parallel_size", 1) self.dcp_size = getattr(vllm_config.parallel_config, "decode_context_parallel_size", 1)
self.dcp_size = getattr(vllm_config.parallel_config,
"decode_context_parallel_size", 1)
self._block_size = vllm_config.cache_config.block_size self._block_size = vllm_config.cache_config.block_size
if self.pcp_size > 1: if self.pcp_size > 1:
@@ -45,9 +46,9 @@ class KVPoolScheduler:
self._request_trackers: dict[str, RequestTracker] = {} self._request_trackers: dict[str, RequestTracker] = {}
self._preempted_req_ids: set[str] = set() self._preempted_req_ids: set[str] = set()
# Whether to discard partial chunks # Whether to discard partial chunks
self._discard_partial_chunks = ( self._discard_partial_chunks = vllm_config.kv_transfer_config.get_from_extra_config(
vllm_config.kv_transfer_config.get_from_extra_config( "discard_partial_chunks", True
"discard_partial_chunks", True)) )
self._unfinished_requests: dict[str, tuple[Request, list[int]]] = {} self._unfinished_requests: dict[str, tuple[Request, list[int]]] = {}
self._unfinished_request_ids: set[str] = set() self._unfinished_request_ids: set[str] = set()
@@ -72,13 +73,11 @@ class KVPoolScheduler:
return 0, False return 0, False
if self._discard_partial_chunks: if self._discard_partial_chunks:
token_len = len(request.prompt_token_ids token_len = len(request.prompt_token_ids) // self._block_size * self._block_size
) // self._block_size * self._block_size
else: else:
token_len = len(request.prompt_token_ids) token_len = len(request.prompt_token_ids)
num_external_hit_tokens = self.client.lookup(token_len, num_external_hit_tokens = self.client.lookup(token_len, request.block_hashes)
request.block_hashes)
if num_external_hit_tokens == request.num_tokens: if num_external_hit_tokens == request.num_tokens:
num_external_hit_tokens -= 1 num_external_hit_tokens -= 1
@@ -107,9 +106,7 @@ class KVPoolScheduler:
return need_to_allocate, self.load_async and not self.use_layerwise return need_to_allocate, self.load_async and not self.use_layerwise
def update_state_after_alloc(self, request: "Request", def update_state_after_alloc(self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int):
blocks: "KVCacheBlocks",
num_external_tokens: int):
""" """
Update KVConnector state after temporary buffer alloc. Update KVConnector state after temporary buffer alloc.
@@ -120,8 +117,7 @@ class KVPoolScheduler:
if num_external_tokens > 0: if num_external_tokens > 0:
local_block_ids = blocks.get_block_ids()[0] local_block_ids = blocks.get_block_ids()[0]
self._unfinished_requests[request.request_id] = (request, self._unfinished_requests[request.request_id] = (request, local_block_ids)
local_block_ids)
self._unfinished_request_ids.add(request.request_id) self._unfinished_request_ids.add(request.request_id)
if request.request_id not in self.load_specs: if request.request_id not in self.load_specs:
# No KV tokens from external KV cache, return # No KV tokens from external KV cache, return
@@ -133,18 +129,20 @@ class KVPoolScheduler:
return return
assert ( assert (
num_external_tokens > 0 and num_external_tokens num_external_tokens > 0
== self.load_specs[request.request_id].kvpool_cached_tokens - and num_external_tokens
self.load_specs[request.request_id].vllm_cached_tokens == self.load_specs[request.request_id].kvpool_cached_tokens
), (f"Mismatch in number of tokens: {num_external_tokens} vs " - self.load_specs[request.request_id].vllm_cached_tokens
), (
f"Mismatch in number of tokens: {num_external_tokens} vs "
f"{self.load_specs[request.request_id].kvpool_cached_tokens} - " f"{self.load_specs[request.request_id].kvpool_cached_tokens} - "
f"{self.load_specs[request.request_id].vllm_cached_tokens}" f"{self.load_specs[request.request_id].vllm_cached_tokens}"
f" for request {request.request_id}") f" for request {request.request_id}"
)
self.load_specs[request.request_id].can_load = True self.load_specs[request.request_id].can_load = True
def build_connector_meta( def build_connector_meta(self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
"""Attach the connector metadata to the request object. """Attach the connector metadata to the request object.
This function should NOT modify other fields in the scheduler_output This function should NOT modify other fields in the scheduler_output
@@ -155,14 +153,13 @@ class KVPoolScheduler:
scheduler_output (SchedulerOutput): the scheduler output object. scheduler_output (SchedulerOutput): the scheduler output object.
""" """
force_skip_save = (self.kv_role == "kv_consumer" force_skip_save = self.kv_role == "kv_consumer" and not self.consumer_is_to_put
and not self.consumer_is_to_put)
for finished_req_id in scheduler_output.finished_req_ids: for finished_req_id in scheduler_output.finished_req_ids:
self._request_trackers.pop(finished_req_id, None) self._request_trackers.pop(finished_req_id, None)
self._unfinished_requests.pop(finished_req_id, None) self._unfinished_requests.pop(finished_req_id, None)
self._unfinished_request_ids.discard(finished_req_id) self._unfinished_request_ids.discard(finished_req_id)
for req_id in scheduler_output.preempted_req_ids: for req_id in scheduler_output.preempted_req_ids:
self._preempted_req_ids.update(scheduler_output.preempted_req_ids) self._preempted_req_ids.update(scheduler_output.preempted_req_ids)
self._request_trackers.pop(req_id, None) self._request_trackers.pop(req_id, None)
@@ -173,9 +170,7 @@ class KVPoolScheduler:
for request in scheduler_output.scheduled_new_reqs: for request in scheduler_output.scheduled_new_reqs:
# Right now, we only load KV for new requests # Right now, we only load KV for new requests
load_spec = self.load_specs.pop(request.req_id, None) load_spec = self.load_specs.pop(request.req_id, None)
num_tokens_to_compute = ( num_tokens_to_compute = request.num_computed_tokens + scheduler_output.num_scheduled_tokens[request.req_id]
request.num_computed_tokens +
scheduler_output.num_scheduled_tokens[request.req_id])
request_tuple = self._unfinished_requests.get(request.req_id) request_tuple = self._unfinished_requests.get(request.req_id)
request_real = request_tuple[0] # type: ignore[index] request_real = request_tuple[0] # type: ignore[index]
if not isinstance(request.block_ids[0], list): if not isinstance(request.block_ids[0], list):
@@ -183,25 +178,25 @@ class KVPoolScheduler:
else: else:
unfolded_block_ids = request.block_ids[0].copy() unfolded_block_ids = request.block_ids[0].copy()
request_tracker = RequestTracker( request_tracker = RequestTracker(
req_id=request.req_id, req_id=request.req_id,
token_len=num_tokens_to_compute, token_len=num_tokens_to_compute,
allocated_block_ids=unfolded_block_ids, allocated_block_ids=unfolded_block_ids,
num_saved_tokens=0, num_saved_tokens=0,
) )
self._request_trackers[request.req_id] = request_tracker self._request_trackers[request.req_id] = request_tracker
last_chunk_tokens_num = ((len(request.prompt_token_ids) // last_chunk_tokens_num = (
self._block_size * self._block_size) (len(request.prompt_token_ids) // self._block_size * self._block_size)
if self._discard_partial_chunks else len( if self._discard_partial_chunks
request.prompt_token_ids)) else len(request.prompt_token_ids)
)
req_meta = ReqMeta.from_request_tracker( req_meta = ReqMeta.from_request_tracker(
request_tracker, request_tracker,
self._block_size, self._block_size,
load_spec=load_spec, load_spec=load_spec,
skip_save=force_skip_save, skip_save=force_skip_save,
block_hashes=request_real.block_hashes, block_hashes=request_real.block_hashes,
is_last_chunk=request_tracker.token_len is_last_chunk=request_tracker.token_len >= last_chunk_tokens_num,
>= last_chunk_tokens_num,
discard_partial_chunks=self._discard_partial_chunks, discard_partial_chunks=self._discard_partial_chunks,
) )
if req_meta is not None: if req_meta is not None:
@@ -224,8 +219,8 @@ class KVPoolScheduler:
request_tuple = self._unfinished_requests.get(req_id) request_tuple = self._unfinished_requests.get(req_id)
request_real = request_tuple[0] # type: ignore[index] request_real = request_tuple[0] # type: ignore[index]
num_tokens_to_compute = ( num_tokens_to_compute = (
request_real.num_computed_tokens + request_real.num_computed_tokens + scheduler_output.num_scheduled_tokens[req_id]
scheduler_output.num_scheduled_tokens[req_id]) )
request_tracker = RequestTracker( request_tracker = RequestTracker(
req_id=req_id, req_id=req_id,
token_len=num_tokens_to_compute, token_len=num_tokens_to_compute,
@@ -233,21 +228,21 @@ class KVPoolScheduler:
num_saved_tokens=0, num_saved_tokens=0,
) )
self._request_trackers[req_id] = request_tracker self._request_trackers[req_id] = request_tracker
last_chunk_tokens_num = ((len(request_real.prompt_token_ids) // last_chunk_tokens_num = (
self._block_size * self._block_size) (len(request_real.prompt_token_ids) // self._block_size * self._block_size)
if self._discard_partial_chunks else len( if self._discard_partial_chunks
request_real.prompt_token_ids)) else len(request_real.prompt_token_ids)
)
req_meta = ReqMeta.from_request_tracker( req_meta = ReqMeta.from_request_tracker(
request_tracker, request_tracker,
self._block_size, self._block_size,
load_spec=load_spec, load_spec=load_spec,
skip_save=force_skip_save, skip_save=force_skip_save,
block_hashes=request_real.block_hashes, block_hashes=request_real.block_hashes,
is_last_chunk=request_tracker.token_len is_last_chunk=request_tracker.token_len >= last_chunk_tokens_num,
>= last_chunk_tokens_num,
discard_partial_chunks=self._discard_partial_chunks, discard_partial_chunks=self._discard_partial_chunks,
) )
# decode/chunked request # decode/chunked request
else: else:
request_tracker = self._request_trackers[req_id] request_tracker = self._request_trackers[req_id]
@@ -256,48 +251,44 @@ class KVPoolScheduler:
if req_tuple: if req_tuple:
request = req_tuple[0] request = req_tuple[0]
num_current_tokens = request_tracker.token_len num_current_tokens = request_tracker.token_len
new_token_ids = request.all_token_ids[ new_token_ids = request.all_token_ids[num_current_tokens : num_current_tokens + num_new_tokens]
num_current_tokens:num_current_tokens + num_new_tokens]
request_tracker.token_len += len(new_token_ids) request_tracker.token_len += len(new_token_ids)
else: else:
raise ValueError( raise ValueError(
f"Request {req_id} is not in _unfinished_requests, " f"Request {req_id} is not in _unfinished_requests, but it is scheduled to be cached"
f"but it is scheduled to be cached") )
num_computed_token = cached_reqs.num_computed_tokens[i] num_computed_token = cached_reqs.num_computed_tokens[i]
if num_computed_token >= len(request.prompt_token_ids): if num_computed_token >= len(request.prompt_token_ids):
continue continue
request_tracker.update(new_block_ids) request_tracker.update(new_block_ids)
last_chunk_tokens_num = ((len(request.prompt_token_ids) // last_chunk_tokens_num = (
self._block_size * self._block_size) (len(request.prompt_token_ids) // self._block_size * self._block_size)
if self._discard_partial_chunks else if self._discard_partial_chunks
len(request.prompt_token_ids)) else len(request.prompt_token_ids)
)
req_meta = ReqMeta.from_request_tracker( req_meta = ReqMeta.from_request_tracker(
request_tracker, request_tracker,
self._block_size, self._block_size,
load_spec=None, load_spec=None,
skip_save=force_skip_save, skip_save=force_skip_save,
block_hashes=request.block_hashes, block_hashes=request.block_hashes,
is_last_chunk=request_tracker.token_len is_last_chunk=request_tracker.token_len >= last_chunk_tokens_num,
>= last_chunk_tokens_num,
discard_partial_chunks=self._discard_partial_chunks, discard_partial_chunks=self._discard_partial_chunks,
) )
if req_meta is not None: if req_meta is not None:
meta.add_request(req_meta) meta.add_request(req_meta)
request_ids = [ request_ids = [req.req_id for req in scheduler_output.scheduled_new_reqs]
req.req_id for req in scheduler_output.scheduled_new_reqs for request_id, (request, block_ids) in self._unfinished_requests.items():
]
for request_id, (request,
block_ids) in self._unfinished_requests.items():
if request_id not in request_ids and request_id not in cached_reqs.req_ids: if request_id not in request_ids and request_id not in cached_reqs.req_ids:
load_spec = self.load_specs.pop(request_id, None) load_spec = self.load_specs.pop(request_id, None)
if not load_spec: if not load_spec:
continue continue
num_tokens_to_compute = load_spec.kvpool_cached_tokens num_tokens_to_compute = load_spec.kvpool_cached_tokens
if (num_tokens_to_compute % self._block_size if (num_tokens_to_compute % self._block_size != 0) and (
!= 0) and (num_tokens_to_compute num_tokens_to_compute == len(request.prompt_token_ids) - 1
== len(request.prompt_token_ids) - 1): ):
num_tokens_to_compute = num_tokens_to_compute + 1 num_tokens_to_compute = num_tokens_to_compute + 1
request_tracker = RequestTracker( request_tracker = RequestTracker(
req_id=request_id, req_id=request_id,
@@ -324,7 +315,7 @@ class KVPoolScheduler:
self, self,
request: "Request", request: "Request",
block_ids: list[int], block_ids: list[int],
) -> tuple[bool, Optional[dict[str, Any]]]: ) -> tuple[bool, dict[str, Any] | None]:
""" """
Once a request is finished, determine whether request blocks Once a request is finished, determine whether request blocks
should be freed now or will be sent asynchronously and freed later. should be freed now or will be sent asynchronously and freed later.
@@ -336,13 +327,11 @@ class KVPoolScheduler:
return False, None return False, None
delay_free_blocks = len(block_ids) > 0 delay_free_blocks = len(block_ids) > 0
if delay_free_blocks: if delay_free_blocks:
logger.info("Delaying free of %d blocks for request %s", logger.info("Delaying free of %d blocks for request %s", len(block_ids), request.request_id)
len(block_ids), request.request_id)
return delay_free_blocks, None return delay_free_blocks, None
class LookupKeyClient: class LookupKeyClient:
def __init__(self, vllm_config: "VllmConfig"): def __init__(self, vllm_config: "VllmConfig"):
self.encoder = MsgpackEncoder() self.encoder = MsgpackEncoder()
self.ctx = zmq.Context() # type: ignore[attr-defined] self.ctx = zmq.Context() # type: ignore[attr-defined]

View File

@@ -1,37 +1,45 @@
import math import math
import threading import threading
from typing import Dict, Generator, Optional, Type from collections.abc import Callable, Generator
import torch import torch
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed import (get_decode_context_model_parallel_rank, from vllm.distributed import (
get_decode_context_model_parallel_world_size, get_decode_context_model_parallel_rank,
get_pcp_group, get_tensor_model_parallel_rank, get_decode_context_model_parallel_world_size,
get_tensor_model_parallel_world_size) get_pcp_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from vllm.logger import logger from vllm.logger import logger
from vllm.v1.core.kv_cache_utils import BlockHash from vllm.v1.core.kv_cache_utils import BlockHash
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.backend import \ from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.backend import Backend
Backend from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.memcache_backend import MemcacheBackend
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.memcache_backend import \ from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.mooncake_backend import MooncakeBackend
MemcacheBackend
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.mooncake_backend import \
MooncakeBackend
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.config_data import ( from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.config_data import (
AscendConnectorMetadata, ChunkedTokenDatabase, KeyMetadata, AscendConnectorMetadata,
LasyerMultiBlockReqMeta, ReqMeta) ChunkedTokenDatabase,
KeyMetadata,
LasyerMultiBlockReqMeta,
ReqMeta,
)
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.kv_transfer import ( from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.kv_transfer import (
KVCacheStoreLayerRecvingThread, KVCacheStoreLayerSendingThread, KVCacheStoreLayerRecvingThread,
KVCacheStoreRecvingThread, KVCacheStoreSendingThread, KVTransferThread) KVCacheStoreLayerSendingThread,
KVCacheStoreRecvingThread,
KVCacheStoreSendingThread,
KVTransferThread,
)
backend_map: Dict[str, Type[Backend]] = { backend_map: dict[str, Callable[..., Backend]] = {
"mooncake": MooncakeBackend, "mooncake": MooncakeBackend,
"memcache": MemcacheBackend, "memcache": MemcacheBackend,
} }
class KVPoolWorker: class KVPoolWorker:
#The main class for the cache engine. # The main class for the cache engine.
def __init__( def __init__(
self, self,
@@ -42,9 +50,7 @@ class KVPoolWorker:
parallel_config = vllm_config.parallel_config parallel_config = vllm_config.parallel_config
self.dp_rank = parallel_config.data_parallel_rank self.dp_rank = parallel_config.data_parallel_rank
self.use_mla = False self.use_mla = False
if (hasattr(model_config, "use_mla") if hasattr(model_config, "use_mla") and isinstance(model_config.use_mla, bool) and model_config.use_mla:
and isinstance(model_config.use_mla, bool)
and model_config.use_mla):
self.use_mla = True self.use_mla = True
self.use_layerwise = use_layerwize self.use_layerwise = use_layerwize
self.tp_rank = get_tensor_model_parallel_rank() self.tp_rank = get_tensor_model_parallel_rank()
@@ -53,19 +59,16 @@ class KVPoolWorker:
self.pp_rank = (parallel_config.rank // self.tp_size) % self.pp_size self.pp_rank = (parallel_config.rank // self.tp_size) % self.pp_size
self.pcp_size = get_pcp_group().world_size self.pcp_size = get_pcp_group().world_size
self.pcp_rank = get_pcp_group( self.pcp_rank = get_pcp_group().rank_in_group if self.pcp_size > 1 else 0
).rank_in_group if self.pcp_size > 1 else 0
self.dcp_size = get_decode_context_model_parallel_world_size() self.dcp_size = get_decode_context_model_parallel_world_size()
self.dcp_rank = get_decode_context_model_parallel_rank( self.dcp_rank = get_decode_context_model_parallel_rank() if self.dcp_size > 1 else 0
) if self.dcp_size > 1 else 0
self.kv_role = vllm_config.kv_transfer_config.kv_role self.kv_role = vllm_config.kv_transfer_config.kv_role
self.load_async = vllm_config.kv_transfer_config.kv_connector_extra_config.get( self.load_async = vllm_config.kv_transfer_config.kv_connector_extra_config.get("load_async", False)
"load_async", False)
self.consumer_is_to_put = vllm_config.kv_transfer_config.kv_connector_extra_config.get( self.consumer_is_to_put = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"consumer_is_to_put", False) "consumer_is_to_put", False
self.backend = vllm_config.kv_transfer_config.kv_connector_extra_config.get( )
"backend", "mooncake") self.backend = vllm_config.kv_transfer_config.kv_connector_extra_config.get("backend", "mooncake")
self.block_size = vllm_config.cache_config.block_size self.block_size = vllm_config.cache_config.block_size
if self.pcp_size > 1: if self.pcp_size > 1:
@@ -88,7 +91,7 @@ class KVPoolWorker:
self.put_step = 1 self.put_step = 1
self.metadata = KeyMetadata( self.metadata = KeyMetadata(
model_config.model.rstrip('/').split('/')[-1], model_config.model.rstrip("/").split("/")[-1],
self.head_or_tp_rank, self.head_or_tp_rank,
self.pcp_rank, self.pcp_rank,
self.dcp_rank, self.dcp_rank,
@@ -99,40 +102,28 @@ class KVPoolWorker:
if self.kv_role == "kv_consumer" and self.consumer_is_to_put: if self.kv_role == "kv_consumer" and self.consumer_is_to_put:
num_hidden_layers = model_config.hf_text_config.num_hidden_layers num_hidden_layers = model_config.hf_text_config.num_hidden_layers
partition_list_str = vllm_config.kv_transfer_config.kv_connector_extra_config.get( partition_list_str = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"prefill_pp_layer_partition", None) "prefill_pp_layer_partition", None
prefill_pp_size = int( )
vllm_config.kv_transfer_config.kv_connector_extra_config.get( prefill_pp_size = int(vllm_config.kv_transfer_config.kv_connector_extra_config.get("prefill_pp_size", 1))
"prefill_pp_size", 1))
if partition_list_str is not None: if partition_list_str is not None:
try: try:
partitions = [ partitions = [int(layer) for layer in partition_list_str.split(",")]
int(layer) for layer in partition_list_str.split(",")
]
except ValueError as err: except ValueError as err:
raise ValueError("Invalid partition string: {}".format( raise ValueError("Invalid partition string: {}".format(partition_list_str)) from err
partition_list_str)) from err
if len(partitions) != prefill_pp_size: if len(partitions) != prefill_pp_size:
raise ValueError( raise ValueError(f"{len(partitions)=} does not match {prefill_pp_size=}.")
f"{len(partitions)=} does not match {prefill_pp_size=}."
)
if sum(partitions) != num_hidden_layers: if sum(partitions) != num_hidden_layers:
raise ValueError( raise ValueError(f"{sum(partitions)=} does not match {num_hidden_layers=}.")
f"{sum(partitions)=} does not match {num_hidden_layers=}."
)
else: else:
layers_per_partition = num_hidden_layers // prefill_pp_size layers_per_partition = num_hidden_layers // prefill_pp_size
partitions = [ partitions = [layers_per_partition for _ in range(prefill_pp_size)]
layers_per_partition for _ in range(prefill_pp_size)
]
if remaining_layers := num_hidden_layers % prefill_pp_size: if remaining_layers := num_hidden_layers % prefill_pp_size:
for i in range(2, remaining_layers + 2): for i in range(2, remaining_layers + 2):
partitions[-i] += 1 partitions[-i] += 1
self.token_database = ChunkedTokenDatabase(self.metadata, self.token_database = ChunkedTokenDatabase(self.metadata, self.block_size, self.use_mla, partitions)
self.block_size,
self.use_mla, partitions)
real_backend = backend_map.get(self.backend.lower()) real_backend = backend_map.get(self.backend.lower())
@@ -142,10 +133,11 @@ class KVPoolWorker:
self.put_step = 1 self.put_step = 1
self.m_store = real_backend( # type: ignore[misc] self.m_store = real_backend( # type: ignore[misc]
parallel_config) parallel_config
)
self.kv_send_thread: Optional[KVTransferThread] = None self.kv_send_thread: KVTransferThread | None = None
self.kv_recv_thread: Optional[KVTransferThread] = None self.kv_recv_thread: KVTransferThread | None = None
self.finished_store_req: set[str] = set() self.finished_store_req: set[str] = set()
@@ -162,11 +154,14 @@ class KVPoolWorker:
block_shape_pe = first_kv_cache_tuple[1].shape[-block_rank:] block_shape_pe = first_kv_cache_tuple[1].shape[-block_rank:]
self.block_len = [ self.block_len = [
first_kv_cache[0].element_size() * math.prod(block_shape_norm), first_kv_cache[0].element_size() * math.prod(block_shape_norm),
first_kv_cache[1].element_size() * math.prod(block_shape_pe) first_kv_cache[1].element_size() * math.prod(block_shape_pe),
] ]
logger.info( logger.info(
"num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s", "num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s",
self.num_blocks, block_shape_norm, block_shape_pe) self.num_blocks,
block_shape_norm,
block_shape_pe,
)
else: else:
# [num_block, block_size, num_head, hidden_dim] # [num_block, block_size, num_head, hidden_dim]
self.num_blocks = first_kv_cache.shape[0] self.num_blocks = first_kv_cache.shape[0]
@@ -174,11 +169,9 @@ class KVPoolWorker:
block_rank = 3 # [block_size, kv_heads, head_dim] block_rank = 3 # [block_size, kv_heads, head_dim]
block_shape = first_kv_cache.shape[-block_rank:] block_shape = first_kv_cache.shape[-block_rank:]
self.block_len = [kv_elem_size * math.prod(block_shape)] self.block_len = [kv_elem_size * math.prod(block_shape)]
logger.info("num_blocks: %s, block_shape: %s", self.num_blocks, logger.info("num_blocks: %s, block_shape: %s", self.num_blocks, block_shape)
block_shape)
logger.info("Registering KV_Caches. use_mla: %s, shape %s", logger.info("Registering KV_Caches. use_mla: %s, shape %s", self.use_mla, first_kv_cache.shape)
self.use_mla, first_kv_cache.shape)
self.kv_caches = kv_caches self.kv_caches = kv_caches
self.kv_caches_base_addr = [] self.kv_caches_base_addr = []
@@ -194,8 +187,7 @@ class KVPoolWorker:
ptrs.append(base_addr) ptrs.append(base_addr)
lengths.append(region_len) lengths.append(region_len)
else: else:
cache_list = [cache_or_caches cache_list = [cache_or_caches] if self.use_mla else cache_or_caches
] if self.use_mla else cache_or_caches
for cache in cache_list: for cache in cache_list:
base_addr = cache.data_ptr() base_addr = cache.data_ptr()
self.kv_caches_base_addr.append(base_addr) self.kv_caches_base_addr.append(base_addr)
@@ -208,33 +200,50 @@ class KVPoolWorker:
if self.use_layerwise: if self.use_layerwise:
self.get_event = threading.Event() self.get_event = threading.Event()
if self.kv_role in ['kv_producer', 'kv_both']: if self.kv_role in ["kv_producer", "kv_both"]:
ready_event_sending = threading.Event() ready_event_sending = threading.Event()
self.kv_send_thread = KVCacheStoreLayerSendingThread( self.kv_send_thread = KVCacheStoreLayerSendingThread(
self.m_store, self.token_database, self.block_size, self.m_store,
self.tp_rank, self.dcp_size, self.put_step, self.token_database,
ready_event_sending, self.num_layers) self.block_size,
self.tp_rank,
self.dcp_size,
self.put_step,
ready_event_sending,
self.num_layers,
)
self.kv_send_thread.start() self.kv_send_thread.start()
ready_event = threading.Event() ready_event = threading.Event()
self.kv_recv_thread = KVCacheStoreLayerRecvingThread( self.kv_recv_thread = KVCacheStoreLayerRecvingThread(
self.m_store, self.token_database, self.block_size, self.m_store,
self.tp_rank, self.dcp_size, ready_event, self.get_event) self.token_database,
self.block_size,
self.tp_rank,
self.dcp_size,
ready_event,
self.get_event,
)
self.kv_recv_thread.start() self.kv_recv_thread.start()
ready_event.wait() ready_event.wait()
else: else:
if self.kv_role in ['kv_producer', 'kv_both' if self.kv_role in ["kv_producer", "kv_both"] or self.consumer_is_to_put:
] or self.consumer_is_to_put:
ready_event_sending = threading.Event() ready_event_sending = threading.Event()
self.kv_send_thread = KVCacheStoreSendingThread( self.kv_send_thread = KVCacheStoreSendingThread(
self.m_store, self.token_database, self.block_size, self.m_store,
self.tp_rank, self.dcp_size, self.put_step, self.kv_role, self.token_database,
ready_event_sending) self.block_size,
self.tp_rank,
self.dcp_size,
self.put_step,
self.kv_role,
ready_event_sending,
)
self.kv_send_thread.start() self.kv_send_thread.start()
if self.load_async: if self.load_async:
ready_event = threading.Event() ready_event = threading.Event()
self.kv_recv_thread = KVCacheStoreRecvingThread( self.kv_recv_thread = KVCacheStoreRecvingThread(
self.m_store, self.token_database, self.block_size, self.m_store, self.token_database, self.block_size, self.tp_rank, self.dcp_size, ready_event
self.tp_rank, self.dcp_size, ready_event) )
self.kv_recv_thread.start() self.kv_recv_thread.start()
ready_event.wait() ready_event.wait()
@@ -243,12 +252,12 @@ class KVPoolWorker:
self.layerwise_retrievers = [] self.layerwise_retrievers = []
for request in metadata.requests: for request in metadata.requests:
load_spec = request.load_spec load_spec = request.load_spec
if load_spec is None or not load_spec.can_load: #load =0 if load_spec is None or not load_spec.can_load: # load =0
continue continue
token_len = request.token_len_chunk token_len = request.token_len_chunk
if (load_spec.kvpool_cached_tokens % self.block_size if (load_spec.kvpool_cached_tokens % self.block_size != 0) and (
!= 0) and (load_spec.kvpool_cached_tokens load_spec.kvpool_cached_tokens == token_len - 1
== token_len - 1): ):
token_len = request.load_spec.kvpool_cached_tokens + 1 token_len = request.load_spec.kvpool_cached_tokens + 1
else: else:
token_len = request.load_spec.kvpool_cached_tokens token_len = request.load_spec.kvpool_cached_tokens
@@ -260,30 +269,27 @@ class KVPoolWorker:
else: else:
if self.load_async: if self.load_async:
self.kv_recv_thread.add_request( # type: ignore[union-attr] self.kv_recv_thread.add_request( # type: ignore[union-attr]
request, ) request,
)
else: else:
addr_list = [] addr_list = []
size_list = [] size_list = []
key_list = [] key_list = []
mask_num = (request.load_spec.vllm_cached_tokens // mask_num = request.load_spec.vllm_cached_tokens // self.block_size * self.block_size
self.block_size * self.block_size)
for start, end, key in self.token_database.process_tokens( for start, end, key in self.token_database.process_tokens(
token_len, request.block_hashes, mask_num): token_len, request.block_hashes, mask_num
addr, size, _ = self.token_database.prepare_value( ):
start, end, request.block_ids) addr, size, _ = self.token_database.prepare_value(start, end, request.block_ids)
key_list.append(key.to_string()) key_list.append(key.to_string())
addr_list.append(addr) addr_list.append(addr)
size_list.append(size) size_list.append(size)
key_list_c = key_list[self.tp_rank % len( key_list_c = key_list[self.tp_rank % len(key_list) :] + key_list[: self.tp_rank % len(key_list)]
key_list):] + key_list[:self.tp_rank % len(key_list)] addr_list_c = (
addr_list_c = addr_list[self.tp_rank % addr_list[self.tp_rank % len(addr_list) :] + addr_list[: self.tp_rank % len(addr_list)]
len(addr_list )
):] + addr_list[:self.tp_rank % size_list_c = (
len(addr_list)] size_list[self.tp_rank % len(size_list) :] + size_list[: self.tp_rank % len(size_list)]
size_list_c = size_list[self.tp_rank % )
len(size_list
):] + size_list[:self.tp_rank %
len(size_list)]
self.m_store.get(key_list_c, addr_list_c, size_list_c) self.m_store.get(key_list_c, addr_list_c, size_list_c)
def wait_for_layer_load(self) -> None: def wait_for_layer_load(self) -> None:
@@ -294,8 +300,7 @@ class KVPoolWorker:
num_retrieved_tokens = ret_token_mask.sum().item() num_retrieved_tokens = ret_token_mask.sum().item()
logger.debug(f"Retrieved {num_retrieved_tokens} tokens") logger.debug(f"Retrieved {num_retrieved_tokens} tokens")
def save_kv_layer(self, def save_kv_layer(self, connector_metadata: AscendConnectorMetadata) -> None:
connector_metadata: AscendConnectorMetadata) -> None:
if self.current_layer == 0: if self.current_layer == 0:
self.layerwise_storers = [] self.layerwise_storers = []
current_event = None current_event = None
@@ -336,15 +341,17 @@ class KVPoolWorker:
continue continue
request.current_event = current_event request.current_event = current_event
self.kv_send_thread.add_stored_request( # type: ignore[union-attr] self.kv_send_thread.add_stored_request( # type: ignore[union-attr]
request.req_id) request.req_id
)
self.kv_send_thread.add_request( # type: ignore[union-attr] self.kv_send_thread.add_request( # type: ignore[union-attr]
request, ) request,
)
def retrieve_layer( def retrieve_layer(
self, self,
request: ReqMeta, request: ReqMeta,
) -> Generator[Optional[torch.Tensor], None, None]: ) -> Generator[torch.Tensor | None, None, None]:
""" """
Retrieve the KV cache in a layerwise manner. Retrieve the KV cache in a layerwise manner.
@@ -359,12 +366,14 @@ class KVPoolWorker:
return: A generator that yields Optional[torch.Tensor]. The tensor will return: A generator that yields Optional[torch.Tensor]. The tensor will
be the boolean mask indicating which tokens are retrieved and will be the boolean mask indicating which tokens are retrieved and will
only be returned in the last iteration. only be returned in the last iteration.
""" """
token_len = request.token_len_chunk token_len = request.token_len_chunk
mask_num = ( mask_num = (
request.load_spec.vllm_cached_tokens # type: ignore[union-attr] request.load_spec.vllm_cached_tokens # type: ignore[union-attr]
// self.block_size * self.block_size) // self.block_size
* self.block_size
)
num_required_tokens = token_len - mask_num num_required_tokens = token_len - mask_num
ret_mask = torch.zeros(token_len, dtype=torch.bool, device="cpu") ret_mask = torch.zeros(token_len, dtype=torch.bool, device="cpu")
@@ -373,8 +382,7 @@ class KVPoolWorker:
ends = [] ends = []
keys = [] keys = []
first_flag = True first_flag = True
for start, end, key in self.token_database.process_tokens( for start, end, key in self.token_database.process_tokens(token_len, request.block_hashes, mask_num):
token_len, request.block_hashes, mask_num):
keys_multi_layer = key.split_layers(self.num_layers) keys_multi_layer = key.split_layers(self.num_layers)
starts.append(start) starts.append(start)
ends.append(end) ends.append(end)
@@ -386,16 +394,16 @@ class KVPoolWorker:
keys = [list(row) for row in zip(*keys)] # [num_layer,block_num] keys = [list(row) for row in zip(*keys)] # [num_layer,block_num]
for layer_id, keys_multi_chunk in enumerate(keys): for layer_id, keys_multi_chunk in enumerate(keys):
if not first_flag: if not first_flag:
is_finish = self.get_event.wait(timeout=3) #try---cache is_finish = self.get_event.wait(timeout=3) # try---cache
if not is_finish: if not is_finish:
logger.info("Layerwise get failed") logger.info("Layerwise get failed")
self.get_event.clear() self.get_event.clear()
req_meta = LasyerMultiBlockReqMeta(request.req_id, req_meta = LasyerMultiBlockReqMeta(
keys_multi_chunk, starts, request.req_id, keys_multi_chunk, starts, ends, request.block_ids, layer_id
ends, request.block_ids, )
layer_id)
self.kv_recv_thread.add_request( # type: ignore[union-attr, call-arg] self.kv_recv_thread.add_request( # type: ignore[union-attr, call-arg]
req_meta) # type: ignore[union-attr, call-arg, arg-type] req_meta
) # type: ignore[union-attr, call-arg, arg-type]
first_flag = False first_flag = False
yield None yield None
else: else:
@@ -405,16 +413,14 @@ class KVPoolWorker:
yield None yield None
retrieved_tokens = torch.sum(ret_mask) retrieved_tokens = torch.sum(ret_mask)
logger.debug(f"Retrieved {retrieved_tokens} " logger.debug(f"Retrieved {retrieved_tokens} out of {num_required_tokens} out of total {token_len} tokens")
f"out of {num_required_tokens} "
f"out of total {token_len} tokens")
yield ret_mask yield ret_mask
def store_layer( def store_layer(
self, self,
request: ReqMeta, request: ReqMeta,
current_event: Optional[torch.npu.Event], current_event: torch.npu.Event | None,
) -> Generator[None, None, None]: ) -> Generator[None, None, None]:
""" """
Store the KV cache in a layerwise manner. Store the KV cache in a layerwise manner.
@@ -439,69 +445,88 @@ class KVPoolWorker:
starts = [] starts = []
ends = [] ends = []
keys = [] keys = []
for start, end, key in self.token_database.process_tokens( for start, end, key in self.token_database.process_tokens(request.token_len_chunk, request.block_hashes):
request.token_len_chunk, request.block_hashes):
keys_multi_layer = key.split_layers(self.num_layers) keys_multi_layer = key.split_layers(self.num_layers)
starts.append(start) starts.append(start)
ends.append(end) ends.append(end)
keys.append(keys_multi_layer) #[block_num,layer_num] keys.append(keys_multi_layer) # [block_num,layer_num]
if keys: if keys:
keys = [list(row) for row in zip(*keys)] #[layer_num,block_num] keys = [list(row) for row in zip(*keys)] # [layer_num,block_num]
for layer_id, keys_multi_chunk in enumerate(keys): for layer_id, keys_multi_chunk in enumerate(keys):
req_meta = LasyerMultiBlockReqMeta(request.req_id, req_meta = LasyerMultiBlockReqMeta(
keys_multi_chunk, starts, request.req_id,
ends, request.block_ids, keys_multi_chunk,
layer_id, starts,
request.is_last_chunk, ends,
current_event) request.block_ids,
layer_id,
request.is_last_chunk,
current_event,
)
self.kv_send_thread.add_request( # type: ignore[union-attr, call-arg] self.kv_send_thread.add_request( # type: ignore[union-attr, call-arg]
req_meta) # type: ignore[union-attr, call-arg, arg-type] req_meta
) # type: ignore[union-attr, call-arg, arg-type]
yield yield
else: else:
for layer_id in range(self.num_layers): for layer_id in range(self.num_layers):
yield yield
def get_finished(self, def get_finished(self, finished_req_ids: set[str], meta: AscendConnectorMetadata) -> tuple[set[str], set[str]]:
finished_req_ids: set[str], meta:AscendConnectorMetadata) -> tuple[set[str], set[str]]:
done_sending = ( done_sending = (
self.get_and_clear_finished_requests( self.get_and_clear_finished_requests(
finished_req_ids, meta # type: ignore[union-attr] finished_req_ids,
) if self.kv_role in ['kv_producer', 'kv_both'] meta, # type: ignore[union-attr]
or self.consumer_is_to_put else set()) )
if self.kv_role in ["kv_producer", "kv_both"] or self.consumer_is_to_put
else set()
)
done_recving = ( done_recving = (
self.kv_recv_thread. self.kv_recv_thread.get_and_clear_finished_requests( # type: ignore[union-attr]
get_and_clear_finished_requests( # type: ignore[union-attr] )
) if self.load_async else set()) if self.load_async
else set()
)
logger.debug( logger.debug(
"Number of completed KV cache send requests: %d, receive " "Number of completed KV cache send requests: %d, receive requests: %d, tp_rank:%d",
"requests: %d, tp_rank:%d", len(done_sending), len(done_recving), len(done_sending),
self.tp_rank) len(done_recving),
self.tp_rank,
)
return done_sending, done_recving return done_sending, done_recving
def get_and_clear_finished_requests(self, finished_req_ids, meta:AscendConnectorMetadata) -> set[str]: def get_and_clear_finished_requests(self, finished_req_ids, meta: AscendConnectorMetadata) -> set[str]:
finished_sending = set() finished_sending = set()
for req_id in meta.preempted_req_ids: for req_id in meta.preempted_req_ids:
self.kv_send_thread.delete_finished_stored_request( # type: ignore[union-attr] self.kv_send_thread.delete_finished_stored_request( # type: ignore[union-attr]
req_id) req_id
)
for req_id in self.kv_send_thread.stored_requests.copy( # type: ignore[union-attr] for req_id in self.kv_send_thread.stored_requests.copy( # type: ignore[union-attr]
): ):
if self.kv_send_thread.stored_requests[ # type: ignore[union-attr] if (
req_id] == 0 and req_id in self.finished_store_req: self.kv_send_thread.stored_requests[ # type: ignore[union-attr]
req_id
]
== 0
and req_id in self.finished_store_req
):
self.finished_store_req.remove(req_id) self.finished_store_req.remove(req_id)
finished_sending.add(req_id) finished_sending.add(req_id)
self.kv_send_thread.delete_finished_stored_request( # type: ignore[union-attr] self.kv_send_thread.delete_finished_stored_request( # type: ignore[union-attr]
req_id) req_id
)
for req_id in finished_req_ids: for req_id in finished_req_ids:
req_remain_jobs = self.kv_send_thread.stored_requests.get( # type: ignore[union-attr] req_remain_jobs = self.kv_send_thread.stored_requests.get( # type: ignore[union-attr]
req_id) req_id
)
if req_remain_jobs == 0: if req_remain_jobs == 0:
finished_sending.add(req_id) finished_sending.add(req_id)
self.kv_send_thread.delete_finished_stored_request( # type: ignore[union-attr] self.kv_send_thread.delete_finished_stored_request( # type: ignore[union-attr]
req_id) req_id
)
elif req_remain_jobs is not None: elif req_remain_jobs is not None:
self.finished_store_req.add(req_id) self.finished_store_req.add(req_id)
@@ -522,8 +547,7 @@ class KVPoolWorker:
keys = [] keys = []
try: try:
starts = [] starts = []
for start, end, key in self.token_database.process_tokens( for start, end, key in self.token_database.process_tokens(token_len, block_hashes):
token_len, block_hashes):
if use_layerwise: if use_layerwise:
keys_multi_layer = key.split_layers(self.num_layers) keys_multi_layer = key.split_layers(self.num_layers)
for item in keys_multi_layer: for item in keys_multi_layer:
@@ -560,8 +584,7 @@ class KVPoolWorker:
keys = [] keys = []
try: try:
starts = [] starts = []
for start, end, key in self.token_database.process_tokens( for start, end, key in self.token_database.process_tokens(token_len, block_hashes):
token_len, block_hashes):
if use_layerwise: if use_layerwise:
keys_multi_layer = key.split_layers(self.num_layers) keys_multi_layer = key.split_layers(self.num_layers)
for item in keys_multi_layer: for item in keys_multi_layer:
@@ -574,25 +597,25 @@ class KVPoolWorker:
for i in range(1, min(self.tp_size, self.num_kv_head)): for i in range(1, min(self.tp_size, self.num_kv_head)):
for item in keys: for item in keys:
new_str = item.replace( # type: ignore[attr-defined] new_str = item.replace( # type: ignore[attr-defined]
"@head_or_tp_rank:0", f"@head_or_tp_rank:{i}", 1) "@head_or_tp_rank:0", f"@head_or_tp_rank:{i}", 1
)
multi_tp_keys.append(new_str) multi_tp_keys.append(new_str)
for i in range(1, self.pp_size): for i in range(1, self.pp_size):
for item in keys: for item in keys:
new_str = item.replace( # type: ignore[attr-defined] new_str = item.replace( # type: ignore[attr-defined]
"@pp_rank:0", f"@pp_rank:{i}", 1) "@pp_rank:0", f"@pp_rank:{i}", 1
)
multi_tp_keys.append(new_str) multi_tp_keys.append(new_str)
res = self.m_store.exists( res = self.m_store.exists(multi_tp_keys) # type: ignore[assignment]
multi_tp_keys) # type: ignore[assignment]
num_block = len(keys) num_block = len(keys)
if use_layerwise: if use_layerwise:
res = self.check_all_layers_exists(res, self.num_layers) res = self.check_all_layers_exists(res, self.num_layers)
num_block = len(keys) // self.num_layers num_block = len(keys) // self.num_layers
multi_tp_values = [ multi_tp_values = [
res[i * num_block:(i + 1) * num_block] # type: ignore[index] res[i * num_block : (i + 1) * num_block] # type: ignore[index]
for i in range( for i in range(min(self.tp_size, self.num_kv_head) * self.pp_size)
min(self.tp_size, self.num_kv_head) * self.pp_size)
] ]
index = self.find_min_first_non_one_index(multi_tp_values) index = self.find_min_first_non_one_index(multi_tp_values)
if index != -1: if index != -1:
@@ -603,8 +626,7 @@ class KVPoolWorker:
return start return start
return end return end
def check_all_layers_exists(self, res: list[int], def check_all_layers_exists(self, res: list[int], num_layers: int) -> list[int]:
num_layers: int) -> list[int]:
total_chunks = len(res) // num_layers total_chunks = len(res) // num_layers
result = [] result = []
@@ -618,7 +640,6 @@ class KVPoolWorker:
def find_min_first_non_one_index(self, arr): def find_min_first_non_one_index(self, arr):
try: try:
return min(idx for row in arr for idx, val in enumerate(row) return min(idx for row in arr for idx, val in enumerate(row) if val != 1)
if val != 1)
except ValueError: except ValueError:
return -1 return -1

View File

@@ -1,20 +1,17 @@
import time import time
from collections import defaultdict from collections import defaultdict
from typing import Optional
from vllm.logger import logger from vllm.logger import logger
from vllm.utils.hashing import sha256 from vllm.utils.hashing import sha256
from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock) from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock
from vllm.v1.core.single_type_kv_cache_manager import \ from vllm.v1.core.single_type_kv_cache_manager import get_manager_for_kv_cache_spec
get_manager_for_kv_cache_spec
from vllm.v1.kv_cache_interface import KVCacheSpec from vllm.v1.kv_cache_interface import KVCacheSpec
from vllm.v1.metrics.stats import (PrefixCacheStats, CachingMetrics) from vllm.v1.metrics.stats import CachingMetrics, PrefixCacheStats
from vllm.v1.request import Request from vllm.v1.request import Request
class CPUCacheStats: class CPUCacheStats:
def __init__(self, enable_prefix_caching: bool, log_stats: bool = False): def __init__(self, enable_prefix_caching: bool, log_stats: bool = False):
self.enable_prefix_caching = enable_prefix_caching self.enable_prefix_caching = enable_prefix_caching
self.log_stats = log_stats self.log_stats = log_stats
@@ -27,10 +24,9 @@ class CPUCacheStats:
# Log the prefix cache hit rate every 10 seconds. # Log the prefix cache hit rate every 10 seconds.
if current_time_sec - self.time_sec >= 10: if current_time_sec - self.time_sec >= 10:
self.time_sec = current_time_sec self.time_sec = current_time_sec
logger.info("CPU Prefix cache hit rate: %.1f%%", logger.info("CPU Prefix cache hit rate: %.1f%%", self.cpu_prefix_cache_metrics.hit_rate * 100)
self.cpu_prefix_cache_metrics.hit_rate * 100)
def make_prefix_cache_stats(self) -> Optional[PrefixCacheStats]: def make_prefix_cache_stats(self) -> PrefixCacheStats | None:
"""Get (and reset) the prefix cache stats. """Get (and reset) the prefix cache stats.
Returns: Returns:
The current prefix caching stats, or None if logging is disabled. The current prefix caching stats, or None if logging is disabled.
@@ -57,7 +53,6 @@ class CPUCacheStats:
class CPUKVCacheManager: class CPUKVCacheManager:
def __init__( def __init__(
self, self,
kv_cache_spec: KVCacheSpec, kv_cache_spec: KVCacheSpec,
@@ -70,30 +65,26 @@ class CPUKVCacheManager:
self.num_cpu_blocks = num_cpu_blocks self.num_cpu_blocks = num_cpu_blocks
self.caching_hash_fn = sha256 if caching_hash_algo == "sha256" else hash self.caching_hash_fn = sha256 if caching_hash_algo == "sha256" else hash
self.use_eagle = use_eagle self.use_eagle = use_eagle
self.block_pool = BlockPool(self.num_cpu_blocks, True, self.block_pool = BlockPool(self.num_cpu_blocks, True, enable_kv_cache_events)
enable_kv_cache_events)
self.single_type_manager = get_manager_for_kv_cache_spec( self.single_type_manager = get_manager_for_kv_cache_spec(
kv_cache_spec=kv_cache_spec, kv_cache_spec=kv_cache_spec,
block_pool=self.block_pool, block_pool=self.block_pool,
kv_cache_group_id=0, kv_cache_group_id=0,
) )
# Record kv block hashes, avoid redundant computation. # Record kv block hashes, avoid redundant computation.
self.req_to_block_hashes: defaultdict[ self.req_to_block_hashes: defaultdict[str, list[BlockHash]] = defaultdict(list)
str, list[BlockHash]] = defaultdict(list)
# Record blocks touched in get_matched_num_and_touch(). # Record blocks touched in get_matched_num_and_touch().
self.req_to_computed_blocks: defaultdict[ self.req_to_computed_blocks: defaultdict[str, list[KVCacheBlock]] = defaultdict(list)
str, list[KVCacheBlock]] = defaultdict(list)
# Record the request that failed to allocate. # Record the request that failed to allocate.
self.req_failed_to_allocate: defaultdict[str, bool] = defaultdict(bool) self.req_failed_to_allocate: defaultdict[str, bool] = defaultdict(bool)
self.req_to_num_tokens: defaultdict[str, int] = defaultdict(int) self.req_to_num_tokens: defaultdict[str, int] = defaultdict(int)
self.cpu_cache_stats = CPUCacheStats(enable_prefix_caching=True, self.cpu_cache_stats = CPUCacheStats(enable_prefix_caching=True, log_stats=True)
log_stats=True)
# Record request that will be free after finish sending # Record request that will be free after finish sending
self.req_to_free: defaultdict[str, Request] = defaultdict(Request) self.req_to_free: defaultdict[str, Request] = defaultdict(Request)
def get_matched_num_and_touch(self, request: Request) -> tuple[int, bool]: def get_matched_num_and_touch(self, request: Request) -> tuple[int, bool]:
# When the request requires prompt logprobs, we skip prefix caching. # When the request requires prompt logprobs, we skip prefix caching.
if (request.sampling_params.prompt_logprobs is not None): if request.sampling_params.prompt_logprobs is not None:
return 0, False return 0, False
request_id = request.request_id request_id = request.request_id
# The block hashes for the request may already be computed # The block hashes for the request may already be computed
@@ -119,10 +110,8 @@ class CPUKVCacheManager:
# cup prefix cache status set and log # cup prefix cache status set and log
assert self.cpu_cache_stats is not None and self.cpu_cache_stats.prefix_cache_stats is not None assert self.cpu_cache_stats is not None and self.cpu_cache_stats.prefix_cache_stats is not None
self.cpu_cache_stats.set_cache_stats(request.num_tokens, self.cpu_cache_stats.set_cache_stats(request.num_tokens, num_computed_tokens)
num_computed_tokens) self.cpu_cache_stats.cpu_prefix_cache_metrics.observe(self.cpu_cache_stats.prefix_cache_stats)
self.cpu_cache_stats.cpu_prefix_cache_metrics.observe(
self.cpu_cache_stats.prefix_cache_stats)
self.cpu_cache_stats.log() self.cpu_cache_stats.log()
return num_computed_tokens, False return num_computed_tokens, False
@@ -130,12 +119,10 @@ class CPUKVCacheManager:
def _release_ahead_touch(self, request_id: str): def _release_ahead_touch(self, request_id: str):
computed_blocks = self.req_to_computed_blocks[request_id] computed_blocks = self.req_to_computed_blocks[request_id]
if computed_blocks: if computed_blocks:
self.single_type_manager.block_pool.free_blocks( self.single_type_manager.block_pool.free_blocks(reversed(computed_blocks))
reversed(computed_blocks))
self.req_to_computed_blocks.pop(request_id, None) self.req_to_computed_blocks.pop(request_id, None)
def allocate_slots(self, req_to_num_tokens: dict[str, int], def allocate_slots(self, req_to_num_tokens: dict[str, int], unallocated_req_ids: set[str]) -> dict[str, list[int]]:
unallocated_req_ids: set[str]) -> dict[str, list[int]]:
for request_id in unallocated_req_ids: for request_id in unallocated_req_ids:
self._free_slots(request_id) self._free_slots(request_id)
req_to_new_blocks = {} req_to_new_blocks = {}
@@ -143,44 +130,34 @@ class CPUKVCacheManager:
if self.req_failed_to_allocate[request_id]: if self.req_failed_to_allocate[request_id]:
continue continue
new_computed_blocks = self.req_to_computed_blocks[request_id] new_computed_blocks = self.req_to_computed_blocks[request_id]
num_blocks_to_allocate = ( num_blocks_to_allocate = self.single_type_manager.get_num_blocks_to_allocate(
self.single_type_manager.get_num_blocks_to_allocate( request_id=request_id,
request_id=request_id, num_tokens=num_tokens,
num_tokens=num_tokens, new_computed_blocks=new_computed_blocks,
new_computed_blocks=new_computed_blocks, )
))
if num_blocks_to_allocate > self.block_pool.get_num_free_blocks(): if num_blocks_to_allocate > self.block_pool.get_num_free_blocks():
self._release_ahead_touch(request_id) self._release_ahead_touch(request_id)
self.req_failed_to_allocate[request_id] = True self.req_failed_to_allocate[request_id] = True
continue continue
# Append the new computed blocks to the request blocks until now to # Append the new computed blocks to the request blocks until now to
# avoid the case where the new blocks cannot be allocated. # avoid the case where the new blocks cannot be allocated.
self.single_type_manager.save_new_computed_blocks( self.single_type_manager.save_new_computed_blocks(request_id, new_computed_blocks)
request_id, new_computed_blocks)
# Allocate new blocks but do not cache now. # Allocate new blocks but do not cache now.
new_blocks = self.single_type_manager.allocate_new_blocks( new_blocks = self.single_type_manager.allocate_new_blocks(request_id, num_tokens)
request_id, num_tokens)
self.req_to_num_tokens[request_id] = num_tokens self.req_to_num_tokens[request_id] = num_tokens
# No need to release ref_cnt because we use officially. # No need to release ref_cnt because we use officially.
self.req_to_computed_blocks.pop(request_id, None) self.req_to_computed_blocks.pop(request_id, None)
req_to_new_blocks[request_id] = [ req_to_new_blocks[request_id] = [block.block_id for block in new_computed_blocks + new_blocks]
block.block_id for block in new_computed_blocks + new_blocks
]
return req_to_new_blocks return req_to_new_blocks
def record_request_cache_and_free_slots(self, request: Request): def record_request_cache_and_free_slots(self, request: Request):
logger.debug( logger.debug(f"record_request_cache_and_free_slots for request {request.request_id} in cpu_kv_cache_manager")
f"record_request_cache_and_free_slots for request {request.request_id} in cpu_kv_cache_manager"
)
self.req_to_free[request.request_id] = request self.req_to_free[request.request_id] = request
def cache_and_free_slots(self, request_id: str): def cache_and_free_slots(self, request_id: str):
logger.debug( logger.debug(f"Cache and free slots for request {request_id} in cpu_kv_cache_manager")
f"Cache and free slots for request {request_id} in cpu_kv_cache_manager"
)
if request_id not in self.req_to_free: if request_id not in self.req_to_free:
logger.Error( logger.Error(f"request {request_id} not in req_to_free, maybe bug!")
f"request {request_id} not in req_to_free, maybe bug!")
return return
request = self.req_to_free[request_id] request = self.req_to_free[request_id]
if not self.req_failed_to_allocate[request_id]: if not self.req_failed_to_allocate[request_id]:
@@ -189,8 +166,7 @@ class CPUKVCacheManager:
self.req_to_num_tokens[request_id], self.req_to_num_tokens[request_id],
) )
self._free_slots(request_id) self._free_slots(request_id)
logger.debug( logger.debug(f"delete request {request_id} in cpu_kv_cache_manager req_to_free")
f"delete request {request_id} in cpu_kv_cache_manager req_to_free")
del self.req_to_free[request_id] del self.req_to_free[request_id]
def _free_slots(self, request_id: str): def _free_slots(self, request_id: str):

View File

@@ -5,15 +5,15 @@ import queue
import threading import threading
import time import time
from collections import defaultdict from collections import defaultdict
from collections.abc import Sequence
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional, Sequence from typing import TYPE_CHECKING, Any, Optional
import torch import torch
from vllm.attention.layer import Attention, MLAAttention from vllm.attention.layer import Attention, MLAAttention
from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer
from vllm.distributed.kv_transfer.kv_connector.v1.base import ( from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
from vllm.distributed.parallel_state import get_pp_group, get_tp_group from vllm.distributed.parallel_state import get_pp_group, get_tp_group
from vllm.logger import logger from vllm.logger import logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
@@ -23,12 +23,14 @@ from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheSpec from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheSpec
from vllm_ascend.distributed.kv_transfer.kv_pool.cpu_offload.metadata import ( from vllm_ascend.distributed.kv_transfer.kv_pool.cpu_offload.metadata import (
MetadataServer, MetadataServerProc, MLAConfig) MetadataServer,
MetadataServerProc,
MLAConfig,
)
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.v1.attention.backend import AttentionMetadata #type: ignore
from vllm.forward_context import ForwardContext from vllm.forward_context import ForwardContext
from vllm.v1.attention.backend import AttentionMetadata # type: ignore
from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.request import Request from vllm.v1.request import Request
@@ -59,20 +61,15 @@ class CPUOffloadingConnectorMetadata(KVConnectorMetadata):
class CPUOffloadingConnector(KVConnectorBase_V1): class CPUOffloadingConnector(KVConnectorBase_V1):
def __init__(
def __init__(self, self, vllm_config: VllmConfig, role: KVConnectorRole, kv_cache_config: Optional["KVCacheConfig"] = None
vllm_config: VllmConfig, ):
role: KVConnectorRole,
kv_cache_config: Optional["KVCacheConfig"] = None):
self._connector_metadata = CPUOffloadingConnectorMetadata(requests={}, finished_req_ids=set()) self._connector_metadata = CPUOffloadingConnectorMetadata(requests={}, finished_req_ids=set())
if not vllm_config.cache_config.enable_prefix_caching: if not vllm_config.cache_config.enable_prefix_caching:
self.connector_scheduler: Optional[ self.connector_scheduler: CPUOffloadingConnectorScheduler | None = None
CPUOffloadingConnectorScheduler] = None self.connector_worker: CPUOffloadingConnectorWorker | None = None
self.connector_worker: Optional[
CPUOffloadingConnectorWorker] = None
elif role == KVConnectorRole.SCHEDULER: elif role == KVConnectorRole.SCHEDULER:
self.connector_scheduler = CPUOffloadingConnectorScheduler( self.connector_scheduler = CPUOffloadingConnectorScheduler(vllm_config)
vllm_config)
self.connector_worker = None self.connector_worker = None
elif role == KVConnectorRole.WORKER: elif role == KVConnectorRole.WORKER:
self.connector_scheduler = None self.connector_scheduler = None
@@ -82,11 +79,9 @@ class CPUOffloadingConnector(KVConnectorBase_V1):
# Worker-side methods # Worker-side methods
# ============================== # ==============================
def bind_connector_metadata( def bind_connector_metadata(self, connector_metadata: KVConnectorMetadata) -> None:
self, connector_metadata: KVConnectorMetadata) -> None:
if self.connector_worker is not None: if self.connector_worker is not None:
assert isinstance(connector_metadata, assert isinstance(connector_metadata, CPUOffloadingConnectorMetadata)
CPUOffloadingConnectorMetadata)
self.connector_worker.bind_connector_metadata(connector_metadata) self.connector_worker.bind_connector_metadata(connector_metadata)
def clear_connector_metadata(self) -> None: def clear_connector_metadata(self) -> None:
@@ -97,8 +92,7 @@ class CPUOffloadingConnector(KVConnectorBase_V1):
if self.connector_worker is not None: if self.connector_worker is not None:
self.connector_worker.register_kv_caches(kv_caches) self.connector_worker.register_kv_caches(kv_caches)
def start_load_kv(self, forward_context: "ForwardContext", def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
**kwargs) -> None:
if self.connector_worker is not None: if self.connector_worker is not None:
self.connector_worker.start_load_kv() self.connector_worker.start_load_kv()
@@ -106,53 +100,42 @@ class CPUOffloadingConnector(KVConnectorBase_V1):
if self.connector_worker is not None: if self.connector_worker is not None:
self.connector_worker.wait_for_layer_load() self.connector_worker.wait_for_layer_load()
def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, def save_kv_layer(
attn_metadata: "AttentionMetadata", **kwargs) -> None: self, layer_name: str, kv_layer: torch.Tensor, attn_metadata: "AttentionMetadata", **kwargs
) -> None:
pass pass
def wait_for_save(self): def wait_for_save(self):
pass pass
def get_finished( def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str] | None, set[str] | None]:
self, finished_req_ids: set[str]
) -> tuple[Optional[set[str]], Optional[set[str]]]:
assert self.connector_worker is not None assert self.connector_worker is not None
return self.connector_worker.get_finished(), None return self.connector_worker.get_finished(), None
# Scheduler-side methods # Scheduler-side methods
# ============================== # ==============================
def get_num_new_matched_tokens( def get_num_new_matched_tokens(self, request: "Request", num_computed_tokens: int) -> tuple[int, bool]:
self, request: "Request",
num_computed_tokens: int) -> tuple[int, bool]:
if self.connector_scheduler is not None: if self.connector_scheduler is not None:
return self.connector_scheduler.get_num_new_matched_tokens( return self.connector_scheduler.get_num_new_matched_tokens(request, num_computed_tokens)
request, num_computed_tokens)
return 0, False return 0, False
def update_state_after_alloc(self, request: "Request", def update_state_after_alloc(self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int):
blocks: "KVCacheBlocks",
num_external_tokens: int):
if self.connector_scheduler is not None: if self.connector_scheduler is not None:
return self.connector_scheduler.update_state_after_alloc(request) return self.connector_scheduler.update_state_after_alloc(request)
def build_connector_meta( def build_connector_meta(self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
if self.connector_scheduler is not None: if self.connector_scheduler is not None:
return self.connector_scheduler.build_connector_meta( return self.connector_scheduler.build_connector_meta(scheduler_output)
scheduler_output)
return KVConnectorMetadata() return KVConnectorMetadata()
def request_finished( def request_finished(self, request: "Request", block_ids: list[int]) -> tuple[bool, dict[str, Any] | None]:
self, request: "Request",
block_ids: list[int]) -> tuple[bool, Optional[dict[str, Any]]]:
if self.connector_scheduler is not None: if self.connector_scheduler is not None:
self.connector_scheduler.request_finished(request) self.connector_scheduler.request_finished(request)
return True, None return True, None
class CPUOffloadingConnectorScheduler: class CPUOffloadingConnectorScheduler:
def __init__(self, vllm_config: VllmConfig): def __init__(self, vllm_config: VllmConfig):
logger.info("init CPUOffloadingConnectorScheduler") logger.info("init CPUOffloadingConnectorScheduler")
self.vllm_config = vllm_config self.vllm_config = vllm_config
@@ -165,22 +148,17 @@ class CPUOffloadingConnectorScheduler:
self.zmq_rpc_client = MetadataServer.ZMQRPCClient() self.zmq_rpc_client = MetadataServer.ZMQRPCClient()
self.zmq_rpc_client.call("post_init") self.zmq_rpc_client.call("post_init")
if vllm_config.kv_transfer_config is not None: if vllm_config.kv_transfer_config is not None:
self.swap_in_threshold = vllm_config.kv_transfer_config.get_from_extra_config( self.swap_in_threshold = vllm_config.kv_transfer_config.get_from_extra_config("swap_in_threshold", 0)
"swap_in_threshold", 0)
else: else:
self.swap_in_threshold = 0 self.swap_in_threshold = 0
logger.info(f"swap_in_threshold: {self.swap_in_threshold}") logger.info(f"swap_in_threshold: {self.swap_in_threshold}")
def get_num_new_matched_tokens( def get_num_new_matched_tokens(self, ori_request: "Request", num_computed_tokens: int) -> tuple[int, bool]:
self, ori_request: "Request",
num_computed_tokens: int) -> tuple[int, bool]:
request = copy.deepcopy(ori_request) request = copy.deepcopy(ori_request)
request.get_hash_new_full_blocks = None request.get_hash_new_full_blocks = None
num_cpu_computed_tokens, load_async = self.zmq_rpc_client.call( num_cpu_computed_tokens, load_async = self.zmq_rpc_client.call("get_matched_num_and_touch", request)
"get_matched_num_and_touch", request)
self.num_gpu_computed_tokens[request.request_id] = num_computed_tokens self.num_gpu_computed_tokens[request.request_id] = num_computed_tokens
self.num_cpu_computed_tokens[ self.num_cpu_computed_tokens[request.request_id] = num_cpu_computed_tokens
request.request_id] = num_cpu_computed_tokens
if num_cpu_computed_tokens - num_computed_tokens >= self.swap_in_threshold: if num_cpu_computed_tokens - num_computed_tokens >= self.swap_in_threshold:
return num_cpu_computed_tokens - num_computed_tokens, load_async return num_cpu_computed_tokens - num_computed_tokens, load_async
else: else:
@@ -189,29 +167,22 @@ class CPUOffloadingConnectorScheduler:
def update_state_after_alloc(self, request: "Request"): def update_state_after_alloc(self, request: "Request"):
self.allocated_req_ids.add(request.request_id) self.allocated_req_ids.add(request.request_id)
def build_connector_meta( def build_connector_meta(self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
num_tokens = {} num_tokens = {}
# process scheduled_new_reqs # process scheduled_new_reqs
for req in scheduler_output.scheduled_new_reqs: for req in scheduler_output.scheduled_new_reqs:
req_id = req.req_id req_id = req.req_id
num_tokens[req_id] = ( num_tokens[req_id] = req.num_computed_tokens + scheduler_output.num_scheduled_tokens[req_id]
req.num_computed_tokens +
scheduler_output.num_scheduled_tokens[req_id])
# process scheduled_cached_reqs # process scheduled_cached_reqs
cached_reqs = scheduler_output.scheduled_cached_reqs cached_reqs = scheduler_output.scheduled_cached_reqs
for idx, req_id in enumerate(cached_reqs.req_ids): for idx, req_id in enumerate(cached_reqs.req_ids):
num_tokens[req_id] = ( num_tokens[req_id] = cached_reqs.num_computed_tokens[idx] + scheduler_output.num_scheduled_tokens[req_id]
cached_reqs.num_computed_tokens[idx] +
scheduler_output.num_scheduled_tokens[req_id])
unallocated_req_ids = set(self.num_gpu_computed_tokens.keys() - unallocated_req_ids = set(
self.allocated_req_ids - self.num_gpu_computed_tokens.keys() - self.allocated_req_ids - scheduler_output.num_scheduled_tokens.keys()
scheduler_output.num_scheduled_tokens.keys()) )
new_cpu_block_ids = self.zmq_rpc_client.call("allocate_slots", new_cpu_block_ids = self.zmq_rpc_client.call("allocate_slots", num_tokens, unallocated_req_ids)
num_tokens,
unallocated_req_ids)
metadata = CPUOffloadingConnectorMetadata( metadata = CPUOffloadingConnectorMetadata(
requests={}, requests={},
finished_req_ids=set(self.finished_req_ids), finished_req_ids=set(self.finished_req_ids),
@@ -222,22 +193,22 @@ class CPUOffloadingConnectorScheduler:
metadata.requests[req_id] = ReqMeta( metadata.requests[req_id] = ReqMeta(
gpu_block_ids=[] if gpu_block_ids is None else gpu_block_ids, gpu_block_ids=[] if gpu_block_ids is None else gpu_block_ids,
cpu_block_ids=new_cpu_block_ids.get(req_id, []), cpu_block_ids=new_cpu_block_ids.get(req_id, []),
num_scheduled_tokens=scheduler_output. num_scheduled_tokens=scheduler_output.num_scheduled_tokens[req_id],
num_scheduled_tokens[req_id],
num_computed_tokens=req.num_computed_tokens, num_computed_tokens=req.num_computed_tokens,
num_gpu_computed_tokens=self.num_gpu_computed_tokens[req_id], num_gpu_computed_tokens=self.num_gpu_computed_tokens[req_id],
num_cpu_computed_tokens=self.num_cpu_computed_tokens[req_id]) num_cpu_computed_tokens=self.num_cpu_computed_tokens[req_id],
)
for idx, req_id in enumerate(cached_reqs.req_ids): for idx, req_id in enumerate(cached_reqs.req_ids):
gpu_block_ids = cached_reqs.new_block_ids[idx] gpu_block_ids = cached_reqs.new_block_ids[idx]
metadata.requests[req_id] = ReqMeta( metadata.requests[req_id] = ReqMeta(
gpu_block_ids=[] if gpu_block_ids is None else gpu_block_ids, gpu_block_ids=[] if gpu_block_ids is None else gpu_block_ids,
cpu_block_ids=new_cpu_block_ids.get(req_id, []), cpu_block_ids=new_cpu_block_ids.get(req_id, []),
num_scheduled_tokens=scheduler_output. num_scheduled_tokens=scheduler_output.num_scheduled_tokens[req_id],
num_scheduled_tokens[req_id],
num_computed_tokens=cached_reqs.num_computed_tokens[idx], num_computed_tokens=cached_reqs.num_computed_tokens[idx],
num_gpu_computed_tokens=cached_reqs.num_computed_tokens[idx], num_gpu_computed_tokens=cached_reqs.num_computed_tokens[idx],
num_cpu_computed_tokens=cached_reqs.num_computed_tokens[idx]) num_cpu_computed_tokens=cached_reqs.num_computed_tokens[idx],
)
self.num_gpu_computed_tokens.clear() self.num_gpu_computed_tokens.clear()
self.num_cpu_computed_tokens.clear() self.num_cpu_computed_tokens.clear()
self.allocated_req_ids.clear() self.allocated_req_ids.clear()
@@ -249,12 +220,10 @@ class CPUOffloadingConnectorScheduler:
request.get_hash_new_full_blocks = None request.get_hash_new_full_blocks = None
self.finished_req_ids.append(request.request_id) self.finished_req_ids.append(request.request_id)
# inform metadata server to record request, and free it after finish sending # inform metadata server to record request, and free it after finish sending
self.zmq_rpc_client.call("record_request_cache_and_free_slots", self.zmq_rpc_client.call("record_request_cache_and_free_slots", request)
request)
class CPUOffloadingConnectorWorker: class CPUOffloadingConnectorWorker:
def __init__(self, vllm_config: VllmConfig): def __init__(self, vllm_config: VllmConfig):
logger.info("init CPUOffloadingConnectorWorker") logger.info("init CPUOffloadingConnectorWorker")
self.vllm_config = vllm_config self.vllm_config = vllm_config
@@ -289,7 +258,7 @@ class CPUOffloadingConnectorWorker:
def init_metadata_server(self, vllm_config: VllmConfig): def init_metadata_server(self, vllm_config: VllmConfig):
self.metadata_thread = threading.Thread( self.metadata_thread = threading.Thread(
target=MetadataServerProc.run_metadata_server, target=MetadataServerProc.run_metadata_server,
args=(vllm_config, ), args=(vllm_config,),
) )
self.metadata_thread.daemon = True self.metadata_thread.daemon = True
self.metadata_thread.start() self.metadata_thread.start()
@@ -304,18 +273,15 @@ class CPUOffloadingConnectorWorker:
logger.info(f"wait for metadata server to start, error: {e}") logger.info(f"wait for metadata server to start, error: {e}")
time.sleep(1) time.sleep(1)
def bind_connector_metadata( def bind_connector_metadata(self, connector_metadata: CPUOffloadingConnectorMetadata) -> None:
self, connector_metadata: CPUOffloadingConnectorMetadata) -> None:
for req_id, req in connector_metadata.requests.items(): for req_id, req in connector_metadata.requests.items():
if req_id in self.requests: if req_id in self.requests:
self.requests[req_id].update(req) self.requests[req_id].update(req)
req = self.requests[req_id] req = self.requests[req_id]
else: else:
self.requests[req_id] = req self.requests[req_id] = req
for i in range(req.num_gpu_computed_tokens // self.block_size, for i in range(req.num_gpu_computed_tokens // self.block_size, req.num_computed_tokens // self.block_size):
req.num_computed_tokens // self.block_size): self.load_block_mapping.append((req.cpu_block_ids[i], req.gpu_block_ids[i]))
self.load_block_mapping.append(
(req.cpu_block_ids[i], req.gpu_block_ids[i]))
for req_id in connector_metadata.finished_req_ids: for req_id in connector_metadata.finished_req_ids:
if req_id in self.requests: if req_id in self.requests:
self.save_input_queue.put((req_id, self.requests[req_id])) self.save_input_queue.put((req_id, self.requests[req_id]))
@@ -326,11 +292,11 @@ class CPUOffloadingConnectorWorker:
def register_kv_caches(self, kv_caches: dict[str, Sequence[torch.Tensor]]): def register_kv_caches(self, kv_caches: dict[str, Sequence[torch.Tensor]]):
self.gpu_kv_caches = kv_caches self.gpu_kv_caches = kv_caches
model_config = self.vllm_config.model_config model_config = self.vllm_config.model_config
mla_config: Optional[MLAConfig] = None mla_config: MLAConfig | None = None
if model_config.use_mla: if model_config.use_mla:
mla_config = MLAConfig( mla_config = MLAConfig(
model_config.hf_text_config.kv_lora_rank, model_config.hf_text_config.kv_lora_rank, model_config.hf_text_config.qk_rope_head_dim
model_config.hf_text_config.qk_rope_head_dim) )
self.cpu_kv_caches = list( self.cpu_kv_caches = list(
self.zmq_rpc_client.call( self.zmq_rpc_client.call(
"init_cpu_kv_caches", "init_cpu_kv_caches",
@@ -338,7 +304,8 @@ class CPUOffloadingConnectorWorker:
self.tp_rank, self.tp_rank,
get_kv_cache_spec(self.vllm_config), get_kv_cache_spec(self.vllm_config),
mla_config, mla_config,
).values()) ).values()
)
def start_load_kv(self) -> None: def start_load_kv(self) -> None:
self.current_layer = 0 self.current_layer = 0
@@ -358,10 +325,8 @@ class CPUOffloadingConnectorWorker:
cpu_kv_caches = self.cpu_kv_caches[layer] cpu_kv_caches = self.cpu_kv_caches[layer]
with torch.npu.stream(self.load_stream): with torch.npu.stream(self.load_stream):
for cpu_block_id, gpu_block_id in self.load_block_mapping: for cpu_block_id, gpu_block_id in self.load_block_mapping:
for gpu_layer_part, cpu_layer_part in zip( for gpu_layer_part, cpu_layer_part in zip(gpu_kv_caches, cpu_kv_caches):
gpu_kv_caches, cpu_kv_caches): gpu_layer_part[gpu_block_id].copy_(cpu_layer_part[cpu_block_id], non_blocking=True)
gpu_layer_part[gpu_block_id].copy_(
cpu_layer_part[cpu_block_id], non_blocking=True)
def get_finished(self) -> set[str]: def get_finished(self) -> set[str]:
done_sending: set[str] = set() done_sending: set[str] = set()
@@ -380,8 +345,7 @@ class CPUOffloadingConnectorWorker:
self.done_sending_count[req_id] += 1 self.done_sending_count[req_id] += 1
other_ranks_finished_ids: list[str] = [] other_ranks_finished_ids: list[str] = []
for i in range(1, self.tp_world_size): for i in range(1, self.tp_world_size):
other_ranks_finished_ids.extend( other_ranks_finished_ids.extend(self.tp_group.recv_object(src=i))
self.tp_group.recv_object(src=i))
for req_id in other_ranks_finished_ids: for req_id in other_ranks_finished_ids:
self.done_sending_count[req_id] += 1 self.done_sending_count[req_id] += 1
all_done_sending: set[str] = set() all_done_sending: set[str] = set()
@@ -391,8 +355,7 @@ class CPUOffloadingConnectorWorker:
all_done_sending.add(req_id) all_done_sending.add(req_id)
# release cpu_kv_cache after request sending finished # release cpu_kv_cache after request sending finished
# to avoid rpc blocking, use thread to call rpc asynchronously # to avoid rpc blocking, use thread to call rpc asynchronously
sending_finished_thread = threading.Thread( sending_finished_thread = threading.Thread(target=self._sending_finished, args=(all_done_sending,))
target=self._sending_finished, args=(all_done_sending, ))
sending_finished_thread.daemon = True sending_finished_thread.daemon = True
sending_finished_thread.start() sending_finished_thread.start()
@@ -411,11 +374,10 @@ class CPUOffloadingConnectorWorker:
while True: while True:
req_id, req = self.save_input_queue.get() req_id, req = self.save_input_queue.get()
for i in range( for i in range(
req.num_cpu_computed_tokens // self.block_size, req.num_cpu_computed_tokens // self.block_size,
min((req.num_computed_tokens + req.num_scheduled_tokens) // min((req.num_computed_tokens + req.num_scheduled_tokens) // self.block_size, len(req.cpu_block_ids)),
self.block_size, len(req.cpu_block_ids))): ):
save_block_mapping.append( save_block_mapping.append((req.gpu_block_ids[i], req.cpu_block_ids[i]))
(req.gpu_block_ids[i], req.cpu_block_ids[i]))
with torch.npu.stream(self.save_stream): with torch.npu.stream(self.save_stream):
# MLA: kv_layer is tuple[tensor, tensor] means (rope, nope). # MLA: kv_layer is tuple[tensor, tensor] means (rope, nope).
# non-MLA: kv_layer is list[tensor], typically means [k, v]. # non-MLA: kv_layer is list[tensor], typically means [k, v].
@@ -425,13 +387,9 @@ class CPUOffloadingConnectorWorker:
start, step = 0, 1 start, step = 0, 1
for i in range(start, len(save_block_mapping), step): for i in range(start, len(save_block_mapping), step):
gpu_block_id, cpu_block_id = save_block_mapping[i] gpu_block_id, cpu_block_id = save_block_mapping[i]
for cpu_kv_caches, gpu_kv_caches in zip( for cpu_kv_caches, gpu_kv_caches in zip(self.cpu_kv_caches, self.gpu_kv_caches.values()):
self.cpu_kv_caches, self.gpu_kv_caches.values()): for cpu_layer_part, gpu_layer_part in zip(cpu_kv_caches, gpu_kv_caches):
for cpu_layer_part, gpu_layer_part in zip( cpu_layer_part[cpu_block_id].copy_(gpu_layer_part[gpu_block_id], non_blocking=True)
cpu_kv_caches, gpu_kv_caches):
cpu_layer_part[cpu_block_id].copy_(
gpu_layer_part[gpu_block_id],
non_blocking=True)
self.save_stream.synchronize() self.save_stream.synchronize()
self.save_output_queue.put(req_id) self.save_output_queue.put(req_id)
save_block_mapping.clear() save_block_mapping.clear()
@@ -453,8 +411,7 @@ def get_kv_cache_spec(vllm_config: VllmConfig) -> dict[str, KVCacheSpec]:
if vllm_config.cache_config.cache_dtype == "auto": if vllm_config.cache_config.cache_dtype == "auto":
kv_cache_dtype = vllm_config.model_config.dtype kv_cache_dtype = vllm_config.model_config.dtype
else: else:
kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[ kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[vllm_config.cache_config.cache_dtype]
vllm_config.cache_config.cache_dtype]
kv_cache_spec: dict[str, KVCacheSpec] = {} kv_cache_spec: dict[str, KVCacheSpec] = {}
attn_layers = get_layers_from_vllm_config(vllm_config, AttentionLayerBase) attn_layers = get_layers_from_vllm_config(vllm_config, AttentionLayerBase)
@@ -472,10 +429,8 @@ def get_kv_cache_spec(vllm_config: VllmConfig) -> dict[str, KVCacheSpec]:
# using DSA. Fix the spec in vLLM is the final way. # using DSA. Fix the spec in vLLM is the final way.
block_size = vllm_config.cache_config.block_size block_size = vllm_config.cache_config.block_size
kv_cache_spec[layer_name] = FullAttentionSpec( kv_cache_spec[layer_name] = FullAttentionSpec(
block_size=block_size, block_size=block_size, num_kv_heads=1, head_size=attn_module.head_size, dtype=kv_cache_dtype
num_kv_heads=1, )
head_size=attn_module.head_size,
dtype=kv_cache_dtype)
elif spec := attn_module.get_kv_cache_spec(vllm_config): elif spec := attn_module.get_kv_cache_spec(vllm_config):
kv_cache_spec[layer_name] = spec kv_cache_spec[layer_name] = spec
@@ -484,8 +439,7 @@ def get_kv_cache_spec(vllm_config: VllmConfig) -> dict[str, KVCacheSpec]:
if len(mamba_layers) > 0: if len(mamba_layers) > 0:
if vllm_config.cache_config.enable_prefix_caching: if vllm_config.cache_config.enable_prefix_caching:
raise NotImplementedError( raise NotImplementedError("Prefix caching is not supported for Mamba yet.")
"Prefix caching is not supported for Mamba yet.")
for layer_name, mamba_module in mamba_layers.items(): for layer_name, mamba_module in mamba_layers.items():
if spec := mamba_module.get_kv_cache_spec(vllm_config): if spec := mamba_module.get_kv_cache_spec(vllm_config):
kv_cache_spec[layer_name] = spec kv_cache_spec[layer_name] = spec

View File

@@ -1,9 +1,10 @@
import math import math
import os import os
import pickle import pickle
from collections.abc import Callable
from dataclasses import dataclass from dataclasses import dataclass
from multiprocessing.shared_memory import SharedMemory from multiprocessing.shared_memory import SharedMemory
from typing import Any, Callable, Optional from typing import Any
import torch import torch
import vllm.envs as envs import vllm.envs as envs
@@ -14,8 +15,7 @@ from vllm.utils.network_utils import make_zmq_socket
from vllm.utils.torch_utils import get_dtype_size from vllm.utils.torch_utils import get_dtype_size
from vllm.v1.kv_cache_interface import AttentionSpec, MLAAttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec, MLAAttentionSpec
from vllm_ascend.distributed.kv_transfer.kv_pool.cpu_offload.cpu_kv_cache_manager import \ from vllm_ascend.distributed.kv_transfer.kv_pool.cpu_offload.cpu_kv_cache_manager import CPUKVCacheManager
CPUKVCacheManager
@dataclass @dataclass
@@ -30,8 +30,7 @@ def get_cpu_offload_connector(vllm_config: VllmConfig) -> KVTransferConfig:
if kv_transfer_config.kv_connector == "CPUOffloadingConnector": if kv_transfer_config.kv_connector == "CPUOffloadingConnector":
return kv_transfer_config return kv_transfer_config
elif kv_transfer_config.kv_connector == "MultiConnector": elif kv_transfer_config.kv_connector == "MultiConnector":
ktcs = kv_transfer_config.kv_connector_extra_config.get( ktcs = kv_transfer_config.kv_connector_extra_config.get("connectors")
"connectors")
for ktc in ktcs: for ktc in ktcs:
kv_transfer_config = KVTransferConfig(**ktc) kv_transfer_config = KVTransferConfig(**ktc)
if kv_transfer_config.kv_connector == "CPUOffloadingConnector": if kv_transfer_config.kv_connector == "CPUOffloadingConnector":
@@ -44,7 +43,6 @@ class MetadataServer:
DEFAULT_CPU_SWAP_SPACE_GB = 800 DEFAULT_CPU_SWAP_SPACE_GB = 800
class ZMQRPCClient: class ZMQRPCClient:
def __init__(self, identity=None): def __init__(self, identity=None):
if identity is None: if identity is None:
identity = f"worker-{os.getpid()}-{id(self)}" identity = f"worker-{os.getpid()}-{id(self)}"
@@ -56,7 +54,8 @@ class MetadataServer:
zmq.DEALER, # type: ignore zmq.DEALER, # type: ignore
bind=False, bind=False,
identity=identity.encode(), identity=identity.encode(),
linger=0) linger=0,
)
def call(self, func_name: str, *args, **kwargs) -> Any: def call(self, func_name: str, *args, **kwargs) -> Any:
request = (func_name, args, kwargs) request = (func_name, args, kwargs)
@@ -74,11 +73,9 @@ class MetadataServer:
self.shared_memory_dict = memory_dict self.shared_memory_dict = memory_dict
result = {} result = {}
for key, shm in memory_dict.items(): for key, shm in memory_dict.items():
tensor = torch.frombuffer( tensor = torch.frombuffer(shm.buf, dtype=layer_dtype).reshape(layer_size)
shm.buf, dtype=layer_dtype).reshape(layer_size)
if mla_config is not None: if mla_config is not None:
tensor = tensor.split( tensor = tensor.split([mla_config.nope_dim, mla_config.rope_dim], dim=-1)
[mla_config.nope_dim, mla_config.rope_dim], dim=-1)
result[key] = tensor result[key] = tensor
return result return result
@@ -86,7 +83,7 @@ class MetadataServer:
# will be finalized by outer process # will be finalized by outer process
self.socket.close() self.socket.close()
self.ctx.term() self.ctx.term()
if hasattr(self, 'shared_memory_dict'): if hasattr(self, "shared_memory_dict"):
for shm in self.shared_memory_dict.values(): for shm in self.shared_memory_dict.values():
shm.close() shm.close()
@@ -96,7 +93,8 @@ class MetadataServer:
kv_transfer_config = get_cpu_offload_connector(vllm_config) kv_transfer_config = get_cpu_offload_connector(vllm_config)
assert kv_transfer_config is not None assert kv_transfer_config is not None
available_memory_gb = kv_transfer_config.get_from_extra_config( available_memory_gb = kv_transfer_config.get_from_extra_config(
"cpu_swap_space_gb", MetadataServer.DEFAULT_CPU_SWAP_SPACE_GB) "cpu_swap_space_gb", MetadataServer.DEFAULT_CPU_SWAP_SPACE_GB
)
self.available_memory = available_memory_gb * 1024 * 1024 * 1024 self.available_memory = available_memory_gb * 1024 * 1024 * 1024
logger.info(f"cpu swap space: {self.available_memory} bytes") logger.info(f"cpu swap space: {self.available_memory} bytes")
self.ctx = zmq.Context() # type: ignore self.ctx = zmq.Context() # type: ignore
@@ -105,7 +103,8 @@ class MetadataServer:
MetadataServer.METADATA_SERVER_ADDRESS, MetadataServer.METADATA_SERVER_ADDRESS,
zmq.ROUTER, # type: ignore zmq.ROUTER, # type: ignore
bind=True, bind=True,
linger=0) linger=0,
)
self.functions: dict[str, Callable] = { self.functions: dict[str, Callable] = {
"init_cpu_kv_caches": self.init_cpu_kv_caches, "init_cpu_kv_caches": self.init_cpu_kv_caches,
"post_init": self.post_init, "post_init": self.post_init,
@@ -133,15 +132,11 @@ class MetadataServer:
tp_rank: int, tp_rank: int,
kv_cache_specs: dict[str, AttentionSpec], kv_cache_specs: dict[str, AttentionSpec],
mla_config: MLAConfig, mla_config: MLAConfig,
) -> tuple[dict[str, SharedMemory], tuple[int, ...], torch.dtype, ) -> tuple[dict[str, SharedMemory], tuple[int, ...], torch.dtype, MLAConfig]:
MLAConfig]:
logger.info(f"receive pp rank: {pp_rank}, tp rank: {tp_rank}") logger.info(f"receive pp rank: {pp_rank}, tp rank: {tp_rank}")
# follow the assumption that each layer has the same spec # follow the assumption that each layer has the same spec
layer = next(iter(kv_cache_specs.values())) layer = next(iter(kv_cache_specs.values()))
assert all([ assert all([layer.page_size_bytes == any.page_size_bytes for any in kv_cache_specs.values()])
layer.page_size_bytes == any.page_size_bytes
for any in kv_cache_specs.values()
])
use_mla = isinstance(layer, MLAAttentionSpec) use_mla = isinstance(layer, MLAAttentionSpec)
# mla shares the same kv cache among different tp # mla shares the same kv cache among different tp
if use_mla: if use_mla:
@@ -154,30 +149,24 @@ class MetadataServer:
available_memory //= self.pipeline_parallel_size available_memory //= self.pipeline_parallel_size
available_memory //= len(kv_cache_specs) available_memory //= len(kv_cache_specs)
num_blocks = available_memory // layer.page_size_bytes num_blocks = available_memory // layer.page_size_bytes
layer_size = (num_blocks, layer.block_size, layer.num_kv_heads, layer_size = (num_blocks, layer.block_size, layer.num_kv_heads, layer.head_size) # type: ignore
layer.head_size) # type: ignore
else: else:
available_memory //= self.world_size available_memory //= self.world_size
available_memory //= len(kv_cache_specs) available_memory //= len(kv_cache_specs)
num_blocks = available_memory // layer.page_size_bytes num_blocks = available_memory // layer.page_size_bytes
layer_size = (2, num_blocks, layer.block_size, layer.num_kv_heads, layer_size = (2, num_blocks, layer.block_size, layer.num_kv_heads, layer.head_size) # type: ignore
layer.head_size) # type: ignore
nbytes = math.prod(layer_size) * get_dtype_size(layer.dtype) nbytes = math.prod(layer_size) * get_dtype_size(layer.dtype)
for layer_name in kv_cache_specs.keys(): for layer_name in kv_cache_specs:
# only this format can share during ZeroMQ+pickle # only this format can share during ZeroMQ+pickle
shared_memory_dict[ shared_memory_dict[layer_name] = MetadataServer._safe_create_shared_memory(
layer_name] = MetadataServer._safe_create_shared_memory( f"cpu_kv_cache_{pp_rank}_{tp_rank}_{layer_name}", nbytes
f"cpu_kv_cache_{pp_rank}_{tp_rank}_{layer_name}", nbytes) )
if use_mla: if use_mla:
assert mla_config is not None assert mla_config is not None
assert layer.head_size == mla_config.rope_dim + mla_config.nope_dim assert layer.head_size == mla_config.rope_dim + mla_config.nope_dim
self.shared_memory[(pp_rank, self.shared_memory[(pp_rank, tp_rank)] = (shared_memory_dict, layer_size, layer.dtype, mla_config)
tp_rank)] = (shared_memory_dict, layer_size,
layer.dtype, mla_config)
else: else:
self.shared_memory[(pp_rank, self.shared_memory[(pp_rank, tp_rank)] = (shared_memory_dict, layer_size, layer.dtype, None)
tp_rank)] = (shared_memory_dict, layer_size,
layer.dtype, None)
if self.num_cpu_blocks == -1 or num_blocks < self.num_cpu_blocks: if self.num_cpu_blocks == -1 or num_blocks < self.num_cpu_blocks:
self.num_cpu_blocks = num_blocks self.num_cpu_blocks = num_blocks
self.layer = layer self.layer = layer
@@ -185,23 +174,20 @@ class MetadataServer:
def post_init(self): def post_init(self):
# different processors in data parallel may call multiple times # different processors in data parallel may call multiple times
if hasattr(self, 'cpu_block_manager'): if hasattr(self, "cpu_block_manager"):
return return
# do shared_memory() at least once # do shared_memory() at least once
logger.info(f"assign cpu num blocks: {self.num_cpu_blocks}") logger.info(f"assign cpu num blocks: {self.num_cpu_blocks}")
assert self.num_cpu_blocks >= 0 assert self.num_cpu_blocks >= 0
self.cpu_block_manager = CPUKVCacheManager(self.layer, self.cpu_block_manager = CPUKVCacheManager(self.layer, self.num_cpu_blocks)
self.num_cpu_blocks) self.functions.update(
self.functions.update({ {
"get_matched_num_and_touch": "get_matched_num_and_touch": self.cpu_block_manager.get_matched_num_and_touch,
self.cpu_block_manager.get_matched_num_and_touch, "allocate_slots": self.cpu_block_manager.allocate_slots,
"allocate_slots": "record_request_cache_and_free_slots": self.cpu_block_manager.record_request_cache_and_free_slots,
self.cpu_block_manager.allocate_slots, "cache_and_free_slots": self.cpu_block_manager.cache_and_free_slots,
"record_request_cache_and_free_slots": }
self.cpu_block_manager.record_request_cache_and_free_slots, )
"cache_and_free_slots":
self.cpu_block_manager.cache_and_free_slots,
})
def serve_step(self): def serve_step(self):
client_id = self.socket.recv() client_id = self.socket.recv()
@@ -228,8 +214,7 @@ class MetadataServer:
def shutdown(self): def shutdown(self):
self.socket.close() self.socket.close()
self.ctx.term() self.ctx.term()
socket_path = MetadataServer.METADATA_SERVER_ADDRESS.replace( socket_path = MetadataServer.METADATA_SERVER_ADDRESS.replace("ipc://", "")
"ipc://", "")
if os.path.exists(socket_path): if os.path.exists(socket_path):
os.remove(socket_path) os.remove(socket_path)
for cached in self.shared_memory.values(): for cached in self.shared_memory.values():
@@ -239,11 +224,9 @@ class MetadataServer:
class MetadataServerProc: class MetadataServerProc:
@staticmethod @staticmethod
def run_metadata_server(vllm_config: VllmConfig): def run_metadata_server(vllm_config: VllmConfig):
if (not vllm_config.cache_config.enable_prefix_caching if not vllm_config.cache_config.enable_prefix_caching or get_cpu_offload_connector(vllm_config) is None:
or get_cpu_offload_connector(vllm_config) is None):
return return
shutdown_requested = False shutdown_requested = False
@@ -257,7 +240,7 @@ class MetadataServerProc:
# Either SIGTERM or SIGINT will terminate the worker # Either SIGTERM or SIGINT will terminate the worker
# signal.signal(signal.SIGTERM, _signal_handler) # signal.signal(signal.SIGTERM, _signal_handler)
# signal.signal(signal.SIGINT, _signal_handler) # signal.signal(signal.SIGINT, _signal_handler)
metadata_server: Optional[MetadataServer] = None metadata_server: MetadataServer | None = None
try: try:
metadata_server = MetadataServer(vllm_config) metadata_server = MetadataServer(vllm_config)
logger.info("Metadata server started.") logger.info("Metadata server started.")

View File

@@ -4,19 +4,21 @@ from typing import TYPE_CHECKING, Any, Optional
import torch import torch
from ucm.integration.vllm.ucm_connector import UCMConnector from ucm.integration.vllm.ucm_connector import UCMConnector
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import ( from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
logger = init_logger(__name__) logger = init_logger(__name__)
# isort: off # isort: off
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.v1.attention.backend import AttentionMetadata # type: ignore from vllm.v1.attention.backend import AttentionMetadata # type: ignore
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import ( from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
KVConnectorPromMetrics, KVConnectorStats, PromMetric, PromMetricT) KVConnectorPromMetrics,
KVConnectorStats,
PromMetric,
PromMetricT,
)
from vllm.forward_context import ForwardContext from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.kv_cache_interface import KVCacheConfig
@@ -25,16 +27,13 @@ if TYPE_CHECKING:
class UCMConnectorV1(KVConnectorBase_V1): class UCMConnectorV1(KVConnectorBase_V1):
def __init__( def __init__(
self, self,
vllm_config: "VllmConfig", vllm_config: "VllmConfig",
role: KVConnectorRole, role: KVConnectorRole,
kv_cache_config: "KVCacheConfig", kv_cache_config: "KVCacheConfig",
): ):
super().__init__(vllm_config=vllm_config, super().__init__(vllm_config=vllm_config, role=role, kv_cache_config=kv_cache_config)
role=role,
kv_cache_config=kv_cache_config)
assert vllm_config.kv_transfer_config is not None assert vllm_config.kv_transfer_config is not None
ImplCls = UCMConnector ImplCls = UCMConnector
@@ -60,8 +59,7 @@ class UCMConnectorV1(KVConnectorBase_V1):
""" """
self._ucm_engine.register_kv_caches(kv_caches) self._ucm_engine.register_kv_caches(kv_caches)
def start_load_kv(self, forward_context: "ForwardContext", def start_load_kv(self, forward_context: "ForwardContext", **kwargs: Any) -> None:
**kwargs: Any) -> None:
""" """
Start loading the KV cache from the connector to vLLM's paged Start loading the KV cache from the connector to vLLM's paged
KV buffer. This is called from the forward context before the KV buffer. This is called from the forward context before the
@@ -110,8 +108,7 @@ class UCMConnectorV1(KVConnectorBase_V1):
attn_metadata (AttentionMetadata): the attention metadata. attn_metadata (AttentionMetadata): the attention metadata.
**kwargs: additional arguments for the save operation. **kwargs: additional arguments for the save operation.
""" """
self._ucm_engine.save_kv_layer(layer_name, kv_layer, attn_metadata, self._ucm_engine.save_kv_layer(layer_name, kv_layer, attn_metadata, **kwargs)
**kwargs)
def wait_for_save(self) -> None: def wait_for_save(self) -> None:
""" """
@@ -131,8 +128,7 @@ class UCMConnectorV1(KVConnectorBase_V1):
""" """
self._ucm_engine.clear_connector_metadata() self._ucm_engine.clear_connector_metadata()
def bind_connector_metadata( def bind_connector_metadata(self, connector_metadata: KVConnectorMetadata) -> None:
self, connector_metadata: KVConnectorMetadata) -> None:
"""Set the connector metadata from the scheduler. """Set the connector metadata from the scheduler.
This function should be called by the model runner every time This function should be called by the model runner every time
@@ -175,20 +171,15 @@ class UCMConnectorV1(KVConnectorBase_V1):
the number of tokens that can be loaded from the the number of tokens that can be loaded from the
external KV cache beyond what is already computed. external KV cache beyond what is already computed.
""" """
return self._ucm_engine.get_num_new_matched_tokens( return self._ucm_engine.get_num_new_matched_tokens(request, num_computed_tokens)
request, num_computed_tokens)
def update_state_after_alloc(self, request: "Request", def update_state_after_alloc(self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int) -> None:
blocks: "KVCacheBlocks",
num_external_tokens: int) -> None:
""" """
Update KVConnector state after block allocation. Update KVConnector state after block allocation.
""" """
self._ucm_engine.update_state_after_alloc(request, blocks, self._ucm_engine.update_state_after_alloc(request, blocks, num_external_tokens)
num_external_tokens)
def build_connector_meta( def build_connector_meta(self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
""" """
Build the connector metadata for this step. Build the connector metadata for this step.
@@ -222,10 +213,7 @@ class UCMConnectorV1(KVConnectorBase_V1):
# ============================== # ==============================
@classmethod @classmethod
def build_kv_connector_stats( def build_kv_connector_stats(cls, data: dict[str, Any] | None = None) -> Optional["KVConnectorStats"]:
cls,
data: dict[str, Any] | None = None
) -> Optional["KVConnectorStats"]:
""" """
KVConnectorStats resolution method. This method allows dynamically KVConnectorStats resolution method. This method allows dynamically
registered connectors to return their own KVConnectorStats object, registered connectors to return their own KVConnectorStats object,

View File

@@ -1,19 +1,16 @@
import ipaddress
import threading import threading
from typing import Optional
from mooncake.engine import TransferEngine # type: ignore from mooncake.engine import TransferEngine # type: ignore
class GlobalTE(): class GlobalTE:
def __init__(self): def __init__(self):
self.transfer_engine = None self.transfer_engine = None
self.is_register_buffer: bool = False self.is_register_buffer: bool = False
self.transfer_engine_lock = threading.Lock() self.transfer_engine_lock = threading.Lock()
self.register_buffer_lock = threading.Lock() self.register_buffer_lock = threading.Lock()
def get_transfer_engine(self, hostname: str, device_name: Optional[str]): def get_transfer_engine(self, hostname: str, device_name: str | None):
if self.transfer_engine is None: if self.transfer_engine is None:
with self.transfer_engine_lock: with self.transfer_engine_lock:
# Double-Checked Locking # Double-Checked Locking
@@ -22,12 +19,9 @@ class GlobalTE():
raise RuntimeError("mooncake is not available") raise RuntimeError("mooncake is not available")
self.transfer_engine = TransferEngine() self.transfer_engine = TransferEngine()
device_name = device_name if device_name is not None else "" device_name = device_name if device_name is not None else ""
ret_value = self.transfer_engine.initialize( ret_value = self.transfer_engine.initialize(hostname, "P2PHANDSHAKE", "ascend", device_name)
hostname, "P2PHANDSHAKE", "ascend", device_name)
if ret_value != 0: if ret_value != 0:
raise RuntimeError( raise RuntimeError(f"TransferEngine initialization failed with ret_value: {ret_value}")
f"TransferEngine initialization failed with ret_value: {ret_value}"
)
return self.transfer_engine return self.transfer_engine
def register_buffer(self, ptrs: list[int], sizes: list[int]): def register_buffer(self, ptrs: list[int], sizes: list[int]):

View File

@@ -6,8 +6,7 @@ import torch.distributed as dist
from vllm_ascend.distributed.parallel_state import get_p_tp_group from vllm_ascend.distributed.parallel_state import get_p_tp_group
def kv_alltoall_and_rearrange(pd_tp_ratio: int, key: torch.Tensor, def kv_alltoall_and_rearrange(pd_tp_ratio: int, key: torch.Tensor, value: torch.TensorType):
value: torch.TensorType):
if pd_tp_ratio <= 1: if pd_tp_ratio <= 1:
return None, None return None, None
elif key is None or value is None: elif key is None or value is None:
@@ -20,22 +19,17 @@ def kv_alltoall_and_rearrange(pd_tp_ratio: int, key: torch.Tensor,
def alltoall_and_rearrange(tp_ratio: int, input_tensor: torch.Tensor): def alltoall_and_rearrange(tp_ratio: int, input_tensor: torch.Tensor):
num_kv_heads = input_tensor.size(1) num_kv_heads = input_tensor.size(1)
output_tensor = torch.zeros_like(input_tensor) output_tensor = torch.zeros_like(input_tensor)
dist.all_to_all_single(output_tensor, dist.all_to_all_single(output_tensor, input_tensor, group=get_p_tp_group().device_group)
input_tensor,
group=get_p_tp_group().device_group)
input_tensor = 0 input_tensor = 0
result = rearrange_output(output_tensor, tp_ratio, num_kv_heads) result = rearrange_output(output_tensor, tp_ratio, num_kv_heads)
output_tensor = 0 output_tensor = 0
return result return result
def rearrange_output(base_output: torch.Tensor, cut_num: int, def rearrange_output(base_output: torch.Tensor, cut_num: int, num_kv_heads: int):
num_kv_heads: int):
size_0 = base_output.size(0) size_0 = base_output.size(0)
if size_0 % cut_num != 0: if size_0 % cut_num != 0:
raise ValueError( raise ValueError(f"The size of dim 0 [{size_0}] must be divisible by the cut_num [{cut_num}]")
f"The size of dim 0 [{size_0}] must be divisible by the cut_num [{cut_num}]"
)
chunk_size = size_0 // cut_num chunk_size = size_0 // cut_num
reshaped = base_output.view(cut_num, chunk_size, -1) reshaped = base_output.view(cut_num, chunk_size, -1)
transposed = reshaped.transpose(0, 1) transposed = reshaped.transpose(0, 1)
@@ -46,16 +40,13 @@ def align_memory(tensor: torch.Tensor, alignment: int) -> torch.Tensor:
data_ptr = tensor.data_ptr() data_ptr = tensor.data_ptr()
aligned_addr = (data_ptr + alignment - 1) // alignment * alignment aligned_addr = (data_ptr + alignment - 1) // alignment * alignment
offset = (aligned_addr - data_ptr) // tensor.element_size() offset = (aligned_addr - data_ptr) // tensor.element_size()
return tensor[int(offset):] return tensor[int(offset) :]
def get_transfer_timeout_value(): def get_transfer_timeout_value():
ascend_transfer_timeout = os.getenv("ASCEND_TRANSFER_TIMEOUT", "") ascend_transfer_timeout = os.getenv("ASCEND_TRANSFER_TIMEOUT", "")
if len(ascend_transfer_timeout) > 0: if len(ascend_transfer_timeout) > 0:
return int(ascend_transfer_timeout) return int(ascend_transfer_timeout)
hccl_rdma_timeout = int(os.getenv('HCCL_RDMA_TIMEOUT', hccl_rdma_timeout = int(os.getenv("HCCL_RDMA_TIMEOUT", "20")) # type: ignore
'20')) # type: ignore hccl_rdma_retry_cnt = int(os.getenv("HCCL_RDMA_RETRY_CNT", "7")) # type: ignore
hccl_rdma_retry_cnt = int(os.getenv('HCCL_RDMA_RETRY_CNT', return int((4.096 * (2**hccl_rdma_timeout)) * hccl_rdma_retry_cnt // 1000 + 3000)
'7')) # type: ignore
return int((4.096 * (2**hccl_rdma_timeout)) * hccl_rdma_retry_cnt // 1000 +
3000)

View File

@@ -4,8 +4,7 @@ from vllm.logger import init_logger
from vllm.utils.platform_utils import is_pin_memory_available from vllm.utils.platform_utils import is_pin_memory_available
from vllm.v1.attention.backend import AttentionBackend # type: ignore from vllm.v1.attention.backend import AttentionBackend # type: ignore
from vllm.v1.kv_offload.mediums import CPULoadStoreSpec, GPULoadStoreSpec from vllm.v1.kv_offload.mediums import CPULoadStoreSpec, GPULoadStoreSpec
from vllm.v1.kv_offload.worker.worker import (OffloadingHandler, from vllm.v1.kv_offload.worker.worker import OffloadingHandler, TransferResult, TransferSpec
TransferResult, TransferSpec)
logger = init_logger(__name__) logger = init_logger(__name__)
@@ -44,7 +43,6 @@ def expand_block_ids(
class CpuNpuOffloadingHandler(OffloadingHandler): class CpuNpuOffloadingHandler(OffloadingHandler):
def __init__( def __init__(
self, self,
gpu_block_size: int, gpu_block_size: int,
@@ -81,20 +79,22 @@ class CpuNpuOffloadingHandler(OffloadingHandler):
cpu_shape[num_blocks_idx] = num_cpu_blocks * self.block_size_factor cpu_shape[num_blocks_idx] = num_cpu_blocks * self.block_size_factor
logger.debug("Allocating CPU tensor of shape %r", cpu_shape) logger.debug("Allocating CPU tensor of shape %r", cpu_shape)
self.cpu_tensors.append(( self.cpu_tensors.append(
torch.zeros( (
cpu_shape, torch.zeros(
dtype=gpu_tensor[0].dtype, cpu_shape,
device="cpu", dtype=gpu_tensor[0].dtype,
pin_memory=pin_memory, device="cpu",
), pin_memory=pin_memory,
torch.zeros( ),
cpu_shape, torch.zeros(
dtype=gpu_tensor[0].dtype, cpu_shape,
device="cpu", dtype=gpu_tensor[0].dtype,
pin_memory=pin_memory, device="cpu",
), pin_memory=pin_memory,
)) ),
)
)
def transfer_async(self, job_id: int, spec: TransferSpec) -> bool: def transfer_async(self, job_id: int, spec: TransferSpec) -> bool:
logger.info("start transfer_async...") logger.info("start transfer_async...")
@@ -123,9 +123,7 @@ class CpuNpuOffloadingHandler(OffloadingHandler):
dst_sub_blocks_to_skip = -src_blocks.size % dst_block_size_factor dst_sub_blocks_to_skip = -src_blocks.size % dst_block_size_factor
src_sub_block_count = src_blocks.size * src_block_size_factor src_sub_block_count = src_blocks.size * src_block_size_factor
assert ( assert src_sub_block_count == dst_blocks.size * dst_block_size_factor - dst_sub_blocks_to_skip
src_sub_block_count == dst_blocks.size * dst_block_size_factor -
dst_sub_blocks_to_skip)
src_to_dst = np.empty((src_sub_block_count, 2), dtype=np.int64) src_to_dst = np.empty((src_sub_block_count, 2), dtype=np.int64)
expand_block_ids(src_blocks, src_block_size_factor, src_to_dst[:, 0]) expand_block_ids(src_blocks, src_block_size_factor, src_to_dst[:, 0])
@@ -137,18 +135,14 @@ class CpuNpuOffloadingHandler(OffloadingHandler):
) )
src_to_dst_tensor = torch.from_numpy(src_to_dst) src_to_dst_tensor = torch.from_numpy(src_to_dst)
event = self.events_pool.pop( event = self.events_pool.pop() if self.events_pool else torch.npu.Event()
) if self.events_pool else torch.npu.Event()
with torch.npu.stream(stream): with torch.npu.stream(stream):
for src_tensor, dst_tensor in zip(src_tensors, dst_tensors): for src_tensor, dst_tensor in zip(src_tensors, dst_tensors):
src_key_cache, src_value_cache = src_tensor[0], src_tensor[1] src_key_cache, src_value_cache = src_tensor[0], src_tensor[1]
dst_key_cache, dst_value_cache = dst_tensor[0], dst_tensor[1] dst_key_cache, dst_value_cache = dst_tensor[0], dst_tensor[1]
torch.ops._C_ascend.swap_blocks(src_key_cache, dst_key_cache, torch.ops._C_ascend.swap_blocks(src_key_cache, dst_key_cache, src_to_dst_tensor)
src_to_dst_tensor) torch.ops._C_ascend.swap_blocks(src_value_cache, dst_value_cache, src_to_dst_tensor)
torch.ops._C_ascend.swap_blocks(src_value_cache,
dst_value_cache,
src_to_dst_tensor)
event.record(stream) event.record(stream)
@@ -175,4 +169,4 @@ class CpuNpuOffloadingHandler(OffloadingHandler):
event = self.transfer_events.get(job_id) event = self.transfer_events.get(job_id)
if event is not None: if event is not None:
# This will block until the NPU event is complete # This will block until the NPU event is complete
event.synchronize() event.synchronize()

View File

@@ -1,48 +1,40 @@
from collections.abc import Iterator from collections.abc import Iterator
from typing import Optional
import torch import torch
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.v1.attention.backend import AttentionBackend # type: ignore from vllm.v1.attention.backend import AttentionBackend # type: ignore
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.kv_offload.abstract import LoadStoreSpec, OffloadingManager from vllm.v1.kv_offload.abstract import LoadStoreSpec, OffloadingManager
from vllm.v1.kv_offload.backends.cpu import CPUBackend from vllm.v1.kv_offload.backends.cpu import CPUBackend
from vllm.v1.kv_offload.lru_manager import LRUOffloadingManager from vllm.v1.kv_offload.lru_manager import LRUOffloadingManager
from vllm.v1.kv_offload.mediums import CPULoadStoreSpec, GPULoadStoreSpec from vllm.v1.kv_offload.mediums import CPULoadStoreSpec, GPULoadStoreSpec
from vllm.v1.kv_offload.spec import OffloadingSpec from vllm.v1.kv_offload.spec import OffloadingSpec
from vllm.v1.kv_offload.worker.worker import OffloadingHandler from vllm.v1.kv_offload.worker.worker import OffloadingHandler
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm_ascend.kv_offload.cpu_npu import CpuNpuOffloadingHandler from vllm_ascend.kv_offload.cpu_npu import CpuNpuOffloadingHandler
class NPUOffloadingSpec(OffloadingSpec): class NPUOffloadingSpec(OffloadingSpec):
def __init__(self, vllm_config: VllmConfig, kv_cache_config: KVCacheConfig | None = None):
def __init__(self,
vllm_config: VllmConfig,
kv_cache_config: Optional[KVCacheConfig] = None):
super().__init__(vllm_config, kv_cache_config) super().__init__(vllm_config, kv_cache_config)
num_cpu_blocks = self.extra_config.get("num_cpu_blocks") num_cpu_blocks = self.extra_config.get("num_cpu_blocks")
if not num_cpu_blocks: if not num_cpu_blocks:
raise Exception( raise Exception("num_cpu_blocks must be specified in kv_connector_extra_config")
"num_cpu_blocks must be specified in kv_connector_extra_config"
)
self.num_cpu_blocks: int = num_cpu_blocks self.num_cpu_blocks: int = num_cpu_blocks
# scheduler-side # scheduler-side
self._manager: Optional[OffloadingManager] = None self._manager: OffloadingManager | None = None
# worker-side # worker-side
self._handler: Optional[OffloadingHandler] = None self._handler: OffloadingHandler | None = None
def get_manager(self) -> OffloadingManager: def get_manager(self) -> OffloadingManager:
if not self._manager: if not self._manager:
kv_events_config = self.vllm_config.kv_events_config kv_events_config = self.vllm_config.kv_events_config
enable_events = (kv_events_config is not None enable_events = kv_events_config is not None and kv_events_config.enable_kv_cache_events
and kv_events_config.enable_kv_cache_events)
self._manager = LRUOffloadingManager( self._manager = LRUOffloadingManager(
CPUBackend(block_size=self.offloaded_block_size, CPUBackend(block_size=self.offloaded_block_size, num_blocks=self.num_cpu_blocks),
num_blocks=self.num_cpu_blocks),
enable_events=enable_events, enable_events=enable_events,
) )
return self._manager return self._manager
@@ -51,8 +43,7 @@ class NPUOffloadingSpec(OffloadingSpec):
self, self,
kv_caches: dict[str, torch.Tensor], kv_caches: dict[str, torch.Tensor],
attn_backends: dict[str, type[AttentionBackend]], attn_backends: dict[str, type[AttentionBackend]],
) -> Iterator[tuple[type[LoadStoreSpec], type[LoadStoreSpec], ) -> Iterator[tuple[type[LoadStoreSpec], type[LoadStoreSpec], OffloadingHandler]]:
OffloadingHandler]]:
if not self._handler: if not self._handler:
self._handler = CpuNpuOffloadingHandler( self._handler = CpuNpuOffloadingHandler(
attn_backends=attn_backends, attn_backends=attn_backends,

View File

@@ -16,11 +16,13 @@
import torch import torch
def bgmv_shrink(inputs: torch.Tensor, def bgmv_shrink(
lora_a_weights: torch.Tensor, inputs: torch.Tensor,
output_tensor: torch.Tensor, lora_a_weights: torch.Tensor,
lora_indices_tensor: torch.Tensor, output_tensor: torch.Tensor,
scaling: float = 1.0): lora_indices_tensor: torch.Tensor,
scaling: float = 1.0,
):
return torch.ops._C_ascend.bgmv_shrink( return torch.ops._C_ascend.bgmv_shrink(
inputs, inputs,
lora_a_weights, lora_a_weights,
@@ -30,11 +32,13 @@ def bgmv_shrink(inputs: torch.Tensor,
) )
def bgmv_expand(inputs: torch.Tensor, def bgmv_expand(
lora_b_weights: torch.Tensor, inputs: torch.Tensor,
output_tensor: torch.Tensor, lora_b_weights: torch.Tensor,
lora_indices_tensor: torch.Tensor, output_tensor: torch.Tensor,
add_inputs: bool = True): lora_indices_tensor: torch.Tensor,
add_inputs: bool = True,
):
return torch.ops._C_ascend.bgmv_expand( return torch.ops._C_ascend.bgmv_expand(
inputs, inputs,
lora_b_weights, lora_b_weights,
@@ -45,16 +49,18 @@ def bgmv_expand(inputs: torch.Tensor,
) )
def bgmv_expand_slice(inputs: torch.Tensor, def bgmv_expand_slice(
lora_b_weights: torch.Tensor, inputs: torch.Tensor,
output_tensor: torch.Tensor, lora_b_weights: torch.Tensor,
lora_indices_tensor: torch.Tensor, output_tensor: torch.Tensor,
slice_offset: int, lora_indices_tensor: torch.Tensor,
slice_size: int, slice_offset: int,
add_inputs: bool = True): slice_size: int,
return torch.ops._C_ascend.bgmv_expand(inputs, lora_b_weights, add_inputs: bool = True,
lora_indices_tensor, output_tensor, ):
slice_offset, slice_size) return torch.ops._C_ascend.bgmv_expand(
inputs, lora_b_weights, lora_indices_tensor, output_tensor, slice_offset, slice_size
)
def sgmv_shrink( def sgmv_shrink(
@@ -69,21 +75,23 @@ def sgmv_shrink(
token_nums: int, token_nums: int,
scaling: float, scaling: float,
): ):
return torch.ops._C_ascend.sgmv_shrink(inputs, lora_a_weights, return torch.ops._C_ascend.sgmv_shrink(
lora_indices_tensor, seq_len_tensor, inputs, lora_a_weights, lora_indices_tensor, seq_len_tensor, output_tensor, scaling
output_tensor, scaling) )
def sgmv_expand(inputs: torch.Tensor, def sgmv_expand(
lora_b_weights: torch.Tensor, inputs: torch.Tensor,
output_tensor: torch.Tensor, lora_b_weights: torch.Tensor,
b_seq_start_loc: torch.Tensor, output_tensor: torch.Tensor,
seq_len_tensor: torch.Tensor, b_seq_start_loc: torch.Tensor,
lora_indices_tensor: torch.Tensor, seq_len_tensor: torch.Tensor,
batches: int, lora_indices_tensor: torch.Tensor,
max_seq_length: int, batches: int,
token_nums: int, max_seq_length: int,
add_inputs: bool = False): token_nums: int,
add_inputs: bool = False,
):
return torch.ops._C_ascend.sgmv_expand( return torch.ops._C_ascend.sgmv_expand(
inputs, inputs,
lora_b_weights, lora_b_weights,
@@ -95,19 +103,20 @@ def sgmv_expand(inputs: torch.Tensor,
) )
def sgmv_expand_slice(inputs: torch.Tensor, def sgmv_expand_slice(
lora_b_weights: torch.Tensor, inputs: torch.Tensor,
output_tensor: torch.Tensor, lora_b_weights: torch.Tensor,
b_seq_start_loc: torch.Tensor, output_tensor: torch.Tensor,
seq_len_tensor: torch.Tensor, b_seq_start_loc: torch.Tensor,
lora_indices_tensor: torch.Tensor, seq_len_tensor: torch.Tensor,
batches: int, lora_indices_tensor: torch.Tensor,
max_seq_length: int, batches: int,
token_nums: int, max_seq_length: int,
slice_offset: int, token_nums: int,
slice_size: int, slice_offset: int,
add_inputs: bool = False): slice_size: int,
return torch.ops._C_ascend.sgmv_expand(inputs, lora_b_weights, add_inputs: bool = False,
lora_indices_tensor, seq_len_tensor, ):
output_tensor, slice_offset, return torch.ops._C_ascend.sgmv_expand(
slice_size) inputs, lora_b_weights, lora_indices_tensor, seq_len_tensor, output_tensor, slice_offset, slice_size
)

View File

@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Callable, Optional, Tuple, Union from collections.abc import Callable
import torch import torch
from vllm.lora.punica_wrapper.punica_base import PunicaWrapperBase from vllm.lora.punica_wrapper.punica_base import PunicaWrapperBase
@@ -18,26 +18,30 @@ class PunicaWrapperNPU(PunicaWrapperBase):
Multi-LoRA, and to provide the interface for the pytorch punica ops. Multi-LoRA, and to provide the interface for the pytorch punica ops.
""" """
def __init__(self, max_num_batched_tokens: int, max_batches: int, def __init__(self, max_num_batched_tokens: int, max_batches: int, device: torch.device | str, **kwargs):
device: Union[torch.device, str], **kwargs): PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, device)
PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches,
device)
refresh_all_lora_classes() refresh_all_lora_classes()
self.lora_config = kwargs.get("lora_config") self.lora_config = kwargs.get("lora_config")
if get_ascend_device_type() == AscendDeviceType._310P or ( if get_ascend_device_type() == AscendDeviceType._310P or (
self.lora_config is not None self.lora_config is not None and self.lora_config.max_lora_rank >= 128
and self.lora_config.max_lora_rank >= 128): ):
from vllm.lora.ops.torch_ops import (bgmv_expand, from vllm.lora.ops.torch_ops import (
bgmv_expand_slice, bgmv_expand,
bgmv_shrink, sgmv_expand, bgmv_expand_slice,
sgmv_expand_slice, bgmv_shrink,
sgmv_shrink) sgmv_expand,
sgmv_expand_slice,
sgmv_shrink,
)
else: else:
from vllm_ascend.lora.lora_ops import (bgmv_expand, from vllm_ascend.lora.lora_ops import (
bgmv_expand_slice, bgmv_expand,
bgmv_shrink, sgmv_expand, bgmv_expand_slice,
sgmv_expand_slice, bgmv_shrink,
sgmv_shrink) sgmv_expand,
sgmv_expand_slice,
sgmv_shrink,
)
self.bgmv_expand = bgmv_expand self.bgmv_expand = bgmv_expand
self.bgmv_expand_slice = bgmv_expand_slice self.bgmv_expand_slice = bgmv_expand_slice
self.bgmv_shrink = bgmv_shrink self.bgmv_shrink = bgmv_shrink
@@ -52,7 +56,7 @@ class PunicaWrapperNPU(PunicaWrapperBase):
w_t_all: torch.Tensor, w_t_all: torch.Tensor,
scale: float, scale: float,
): ):
#No LoRA request, so return directly # No LoRA request, so return directly
if self.no_lora: if self.no_lora:
return return
self.sgmv_shrink( self.sgmv_shrink(
@@ -79,7 +83,7 @@ class PunicaWrapperNPU(PunicaWrapperBase):
w_t_all: torch.Tensor, w_t_all: torch.Tensor,
add_inputs: bool, add_inputs: bool,
): ):
#No LoRA request, so return directly # No LoRA request, so return directly
if self.no_lora: if self.no_lora:
return return
self.sgmv_expand( self.sgmv_expand(
@@ -108,7 +112,7 @@ class PunicaWrapperNPU(PunicaWrapperBase):
y_slice_size: int, y_slice_size: int,
add_inputs: bool, add_inputs: bool,
): ):
#No LoRA request, so return directly # No LoRA request, so return directly
if self.no_lora: if self.no_lora:
return return
self.sgmv_expand_slice( self.sgmv_expand_slice(
@@ -130,8 +134,7 @@ class PunicaWrapperNPU(PunicaWrapperBase):
y_slice_size: int, y_slice_size: int,
add_inputs: bool, add_inputs: bool,
): ):
self.bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, self.bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset, y_slice_size, add_inputs)
y_offset, y_slice_size, add_inputs)
def _apply_expand( def _apply_expand(
self, self,
@@ -148,13 +151,10 @@ class PunicaWrapperNPU(PunicaWrapperBase):
GEMM of lora'b. GEMM of lora'b.
""" """
expand_slice_fun: Callable = (self._expand_slice_prefill expand_slice_fun: Callable = self._expand_slice_prefill if self.is_prefill else self._expand_slice_decode
if self.is_prefill else
self._expand_slice_decode)
expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_inputs) expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_inputs)
def _apply_shrink(self, y: torch.Tensor, x: torch.Tensor, def _apply_shrink(self, y: torch.Tensor, x: torch.Tensor, w_t_all: torch.Tensor, scale: float):
w_t_all: torch.Tensor, scale: float):
""" """
Perform the ` y+=x@w_t_all` computation, which is suitable for the Perform the ` y+=x@w_t_all` computation, which is suitable for the
GEMM of lora'a. GEMM of lora'a.
@@ -165,14 +165,18 @@ class PunicaWrapperNPU(PunicaWrapperBase):
""" """
y_org = y y_org = y
y = y.view(-1, y.shape[-1]) y = y.view(-1, y.shape[-1])
shrink_fun: Callable = (self._shrink_prefill shrink_fun: Callable = self._shrink_prefill if self.is_prefill else self._shrink_decode
if self.is_prefill else self._shrink_decode)
shrink_fun(y, x, w_t_all, scale) shrink_fun(y, x, w_t_all, scale)
y = y.view_as(y_org) y = y.view_as(y_org)
def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], def add_shrink(
x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...], self,
scale: float, **kwargs): y: tuple[torch.Tensor, ...] | torch.Tensor,
x: torch.Tensor,
lora_a_stacked: tuple[torch.Tensor, ...],
scale: float,
**kwargs,
):
""" """
Performs GEMM for multiple slices of lora_a. Performs GEMM for multiple slices of lora_a.
When `is_prefill is` true, it indicates that it is currently the When `is_prefill is` true, it indicates that it is currently the
@@ -194,18 +198,19 @@ class PunicaWrapperNPU(PunicaWrapperBase):
x = x.view(-1, x.shape[-1]) x = x.view(-1, x.shape[-1])
# TODO fuse these kernels # TODO fuse these kernels
for slice_idx in range(len(lora_a_stacked)): for slice_idx in range(len(lora_a_stacked)):
self._apply_shrink(y[slice_idx], x, lora_a_stacked[slice_idx], self._apply_shrink(y[slice_idx], x, lora_a_stacked[slice_idx], scale)
scale)
def add_expand(self, def add_expand(
y: torch.Tensor, self,
x: Union[Tuple[torch.Tensor, ...], torch.Tensor], y: torch.Tensor,
lora_b_stacked: Tuple[torch.Tensor, ...], x: tuple[torch.Tensor, ...] | torch.Tensor,
lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], lora_b_stacked: tuple[torch.Tensor, ...],
output_slices: Tuple[int, ...], lora_bias_stacked: tuple[torch.Tensor, ...] | None,
offset_start: int = 0, output_slices: tuple[int, ...],
add_inputs=True, offset_start: int = 0,
**kwargs) -> None: add_inputs=True,
**kwargs,
) -> None:
""" """
Performs GEMM and bias addition for multiple slices of lora_b. Performs GEMM and bias addition for multiple slices of lora_b.
@@ -229,8 +234,7 @@ class PunicaWrapperNPU(PunicaWrapperBase):
y = y.view(-1, y.shape[-1]) y = y.view(-1, y.shape[-1])
offset_left = offset_start offset_left = offset_start
if lora_bias_stacked is not None: if lora_bias_stacked is not None:
self._apply_bias(self.token_lora_indices, y, output_slices, self._apply_bias(self.token_lora_indices, y, output_slices, lora_bias_stacked)
lora_bias_stacked)
for slice_idx in range(len(lora_b_stacked)): for slice_idx in range(len(lora_b_stacked)):
self._apply_expand( self._apply_expand(
y, y,
@@ -243,12 +247,9 @@ class PunicaWrapperNPU(PunicaWrapperBase):
offset_left += output_slices[slice_idx] offset_left += output_slices[slice_idx]
y = y.view_as(y_org) y = y.view_as(y_org)
def add_lora_embedding(self, def add_lora_embedding(
y: torch.Tensor, self, y: torch.Tensor, x: torch.Tensor, lora_b_stacked: torch.Tensor, add_inputs: bool = True, **kwargs
x: torch.Tensor, ) -> None:
lora_b_stacked: torch.Tensor,
add_inputs: bool = True,
**kwargs) -> None:
""" """
Applies lora specifically for VocabParallelEmbeddingWithLoRA. Applies lora specifically for VocabParallelEmbeddingWithLoRA.
@@ -263,21 +264,22 @@ class PunicaWrapperNPU(PunicaWrapperBase):
""" """
# Embedding layer only need expand op # Embedding layer only need expand op
expand_fun: Callable = (self._expand_prefill expand_fun: Callable = self._expand_prefill if self.is_prefill else self._expand_decode
if self.is_prefill else self._expand_decode)
x = x.to(torch.float32) x = x.to(torch.float32)
expand_fun(y, x, lora_b_stacked, add_inputs) expand_fun(y, x, lora_b_stacked, add_inputs)
def add_lora_linear(self, def add_lora_linear(
y: torch.Tensor, self,
x: torch.Tensor, y: torch.Tensor,
lora_a_stacked: Tuple[torch.Tensor, ...], x: torch.Tensor,
lora_b_stacked: Tuple[torch.Tensor, ...], lora_a_stacked: tuple[torch.Tensor, ...],
scale: float, lora_b_stacked: tuple[torch.Tensor, ...],
output_slices: Tuple[int, ...], scale: float,
*, output_slices: tuple[int, ...],
buffer: Optional[Tuple[torch.Tensor, ...]] = None, *,
**kwargs) -> None: buffer: tuple[torch.Tensor, ...] | None = None,
**kwargs,
) -> None:
""" """
Applicable to linear-related lora. Applicable to linear-related lora.
@@ -308,27 +310,22 @@ class PunicaWrapperNPU(PunicaWrapperBase):
# We set the buffer to be float32 by default, consistent with the # We set the buffer to be float32 by default, consistent with the
# triton op # triton op
buffer = tuple( buffer = tuple(
torch.zeros( torch.zeros((x.size(0), r), dtype=torch.float32, device=x.device) for _ in range(len(output_slices))
(x.size(0), r), dtype=torch.float32, device=x.device) )
for _ in range(len(output_slices)))
self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs) self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs)
self.add_expand(y, self.add_expand(y, buffer, lora_b_stacked, None, output_slices, add_inputs=True, **kwargs)
buffer,
lora_b_stacked,
None,
output_slices,
add_inputs=True,
**kwargs)
def add_lora_logits(self, def add_lora_logits(
y: torch.Tensor, self,
x: torch.Tensor, y: torch.Tensor,
lora_a_stacked: torch.Tensor, x: torch.Tensor,
lora_b_stacked: torch.Tensor, lora_a_stacked: torch.Tensor,
scale, lora_b_stacked: torch.Tensor,
*, scale,
buffer: Optional[torch.Tensor] = None, *,
**kwargs) -> None: buffer: torch.Tensor | None = None,
**kwargs,
) -> None:
""" """
Applies lora specifically for LogitsProcessorWithLoRA. Applies lora specifically for LogitsProcessorWithLoRA.
@@ -350,9 +347,7 @@ class PunicaWrapperNPU(PunicaWrapperBase):
r = lora_b_stacked.size(-1) r = lora_b_stacked.size(-1)
if buffer is None: if buffer is None:
buffer = torch.zeros((x.size(0), r), buffer = torch.zeros((x.size(0), r), dtype=torch.float32, device=x.device)
dtype=torch.float32,
device=x.device)
indices = self.sampler_indices indices = self.sampler_indices

View File

@@ -1,91 +1,75 @@
from typing import Optional
import vllm import vllm
from torch import nn from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.config import LoRAConfig from vllm.config import LoRAConfig
from vllm.lora.layers import (ColumnParallelLinearWithLoRA, from vllm.lora.layers import (
MergedColumnParallelLinearWithLoRA, ColumnParallelLinearWithLoRA,
MergedQKVParallelLinearWithLoRA, MergedColumnParallelLinearWithLoRA,
QKVParallelLinearWithLoRA, MergedQKVParallelLinearWithLoRA,
RowParallelLinearWithLoRA, QKVParallelLinearWithLoRA,
VocabParallelEmbeddingWithLoRA) RowParallelLinearWithLoRA,
VocabParallelEmbeddingWithLoRA,
)
from vllm.lora.layers.utils import _not_fully_sharded_can_replace from vllm.lora.layers.utils import _not_fully_sharded_can_replace
from vllm_ascend.ops.linear import (AscendColumnParallelLinear, from vllm_ascend.ops.linear import (
AscendMergedColumnParallelLinear, AscendColumnParallelLinear,
AscendQKVParallelLinear, AscendMergedColumnParallelLinear,
AscendRowParallelLinear) AscendQKVParallelLinear,
from vllm_ascend.ops.vocab_parallel_embedding import \ AscendRowParallelLinear,
AscendVocabParallelEmbedding )
from vllm_ascend.ops.vocab_parallel_embedding import AscendVocabParallelEmbedding
class AscendColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): class AscendColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
@classmethod @classmethod
def can_replace_layer( def can_replace_layer(
cls, cls,
source_layer: nn.Module, source_layer: nn.Module,
lora_config: LoRAConfig, lora_config: LoRAConfig,
packed_modules_list: list, packed_modules_list: list,
model_config: Optional[PretrainedConfig], model_config: PretrainedConfig | None,
) -> bool: ) -> bool:
return type(source_layer) is AscendColumnParallelLinear return type(source_layer) is AscendColumnParallelLinear
class AscendMergedColumnParallelLinearWithLoRA( class AscendMergedColumnParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA):
MergedColumnParallelLinearWithLoRA):
@classmethod @classmethod
def can_replace_layer( def can_replace_layer(
cls, cls,
source_layer: nn.Module, source_layer: nn.Module,
lora_config: LoRAConfig, lora_config: LoRAConfig,
packed_modules_list: list, packed_modules_list: list,
model_config: Optional[PretrainedConfig], model_config: PretrainedConfig | None,
) -> bool: ) -> bool:
return type(source_layer) is AscendMergedColumnParallelLinear return type(source_layer) is AscendMergedColumnParallelLinear
class AscendRowParallelLinearWithLoRA(RowParallelLinearWithLoRA): class AscendRowParallelLinearWithLoRA(RowParallelLinearWithLoRA):
@classmethod @classmethod
def can_replace_layer( def can_replace_layer(
cls, cls,
source_layer: nn.Module, source_layer: nn.Module,
lora_config: LoRAConfig, lora_config: LoRAConfig,
packed_modules_list: list, packed_modules_list: list,
model_config: Optional[PretrainedConfig], model_config: PretrainedConfig | None,
) -> bool: ) -> bool:
return type(source_layer) is AscendRowParallelLinear return type(source_layer) is AscendRowParallelLinear
class AscendVocabParallelEmbeddingWithLoRA(VocabParallelEmbeddingWithLoRA): class AscendVocabParallelEmbeddingWithLoRA(VocabParallelEmbeddingWithLoRA):
@classmethod @classmethod
def can_replace_layer( def can_replace_layer(
cls, cls,
source_layer: nn.Module, source_layer: nn.Module,
lora_config: LoRAConfig, lora_config: LoRAConfig,
packed_modules_list: list, packed_modules_list: list,
model_config: Optional[PretrainedConfig], model_config: PretrainedConfig | None,
) -> bool: ) -> bool:
return type(source_layer) is AscendVocabParallelEmbedding return type(source_layer) is AscendVocabParallelEmbedding
class AscendQKVParallelLinearWithLoRA(QKVParallelLinearWithLoRA): class AscendQKVParallelLinearWithLoRA(QKVParallelLinearWithLoRA):
@classmethod
@_not_fully_sharded_can_replace
def can_replace_layer(cls, source_layer: nn.Module,
lora_config: LoRAConfig, packed_modules_list: list,
model_config: Optional[PretrainedConfig]) -> bool:
return type(source_layer) is AscendQKVParallelLinear and len(
packed_modules_list) == 1
class AscendMergedQKVParallelLinearWithLoRA(MergedQKVParallelLinearWithLoRA):
@classmethod @classmethod
@_not_fully_sharded_can_replace @_not_fully_sharded_can_replace
def can_replace_layer( def can_replace_layer(
@@ -93,18 +77,28 @@ class AscendMergedQKVParallelLinearWithLoRA(MergedQKVParallelLinearWithLoRA):
source_layer: nn.Module, source_layer: nn.Module,
lora_config: LoRAConfig, lora_config: LoRAConfig,
packed_modules_list: list, packed_modules_list: list,
model_config: Optional[PretrainedConfig], model_config: PretrainedConfig | None,
) -> bool: ) -> bool:
return (type(source_layer) is AscendQKVParallelLinear return type(source_layer) is AscendQKVParallelLinear and len(packed_modules_list) == 1
and len(packed_modules_list) == 3)
class AscendMergedQKVParallelLinearWithLoRA(MergedQKVParallelLinearWithLoRA):
@classmethod
@_not_fully_sharded_can_replace
def can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: PretrainedConfig | None,
) -> bool:
return type(source_layer) is AscendQKVParallelLinear and len(packed_modules_list) == 3
def refresh_all_lora_classes(): def refresh_all_lora_classes():
vllm.lora.utils._all_lora_classes.add(AscendColumnParallelLinearWithLoRA) vllm.lora.utils._all_lora_classes.add(AscendColumnParallelLinearWithLoRA)
vllm.lora.utils._all_lora_classes.add( vllm.lora.utils._all_lora_classes.add(AscendMergedColumnParallelLinearWithLoRA)
AscendMergedColumnParallelLinearWithLoRA)
vllm.lora.utils._all_lora_classes.add(AscendRowParallelLinearWithLoRA) vllm.lora.utils._all_lora_classes.add(AscendRowParallelLinearWithLoRA)
vllm.lora.utils._all_lora_classes.add(AscendVocabParallelEmbeddingWithLoRA) vllm.lora.utils._all_lora_classes.add(AscendVocabParallelEmbeddingWithLoRA)
vllm.lora.utils._all_lora_classes.add(AscendQKVParallelLinearWithLoRA) vllm.lora.utils._all_lora_classes.add(AscendQKVParallelLinearWithLoRA)
vllm.lora.utils._all_lora_classes.add( vllm.lora.utils._all_lora_classes.add(AscendMergedQKVParallelLinearWithLoRA)
AscendMergedQKVParallelLinearWithLoRA)