### What this PR does / why we need it?
**Scope of Changes**:

| File Path |
| :--- |
| `.../distributed/kv_transfer/kv_pool/ascend_store/ascend_store_connector.py` |
| `vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/backend/backend.py` |
| `.../distributed/kv_transfer/kv_pool/ascend_store/backend/memcache_backend.py` |
| `.../distributed/kv_transfer/kv_pool/ascend_store/backend/mooncake_backend.py` |
| `vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/config_data.py` |
| `vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/kv_transfer.py` |
| `vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/pool_scheduler.py` |
| `vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/pool_worker.py` |
| `.../distributed/kv_transfer/kv_pool/cpu_offload/cpu_kv_cache_manager.py` |
| `.../distributed/kv_transfer/kv_pool/cpu_offload/cpu_offload_connector.py` |
| `vllm_ascend/distributed/kv_transfer/kv_pool/cpu_offload/metadata.py` |
| `vllm_ascend/distributed/kv_transfer/kv_pool/ucm_connector.py` |
| `vllm_ascend/distributed/kv_transfer/utils/mooncake_transfer_engine.py` |
| `vllm_ascend/distributed/kv_transfer/utils/utils.py` |
| `vllm_ascend/kv_offload/cpu_npu.py` |
| `vllm_ascend/kv_offload/npu.py` |
| `vllm_ascend/lora/lora_ops.py` |
| `vllm_ascend/lora/punica_npu.py` |
| `vllm_ascend/lora/utils.py` |
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.13.0
- vLLM main: 2c24bc6996
---------
Signed-off-by: MrZ20 <2609716663@qq.com>
Signed-off-by: SILONG ZENG <2609716663@qq.com>
**`mypy.ini`**
```diff
@@ -1,6 +1,8 @@
 [mypy]
 ; warn_return_any = True
 warn_unused_configs = True
+; disable errors about unchecked annotations for now.
+disable_error_code = annotation-unchecked

 ; Suppress all missing import errors from torch_npu for mypy.
 [mypy-torch_npu.*]
@@ -31,4 +33,4 @@ ignore_missing_imports = True
 ignore_missing_imports = True

 [mypy-ucm.*]
 ignore_missing_imports = True
```
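The `annotation-unchecked` code silences mypy's note about variable annotations inside untyped functions, whose bodies mypy skips by default. A minimal sketch of the pattern being silenced (the function below is hypothetical, not from this PR):

```python
# Sketch of what `disable_error_code = annotation-unchecked` suppresses:
# this function has an unannotated signature, so mypy does not check its
# body, and the local variable annotation would otherwise trigger an
# `annotation-unchecked` note ("By default the bodies of untyped functions
# are not checked").
def build_sizes(block_len, block_size):  # untyped signature -> body unchecked
    sizes: list[int] = []  # annotation inside an unchecked body
    sizes.append(block_len // block_size)
    return sizes
```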
**Formatter configuration** (the reformatted folders leave the `exclude` list):

```diff
@@ -51,11 +51,6 @@ line-length = 120
 # Folder to be modified
 exclude = [
     "tests/**",
-    # (5)
-    "vllm_ascend/distributed/kv_transfer/kv_pool/**",
-    "vllm_ascend/distributed/kv_transfer/utils/**",
-    "vllm_ascend/kv_offload/**",
-    "vllm_ascend/lora/**",
     # (7)
     "vllm_ascend/quantization/**",
    "vllm_ascend/sample/*.py",
```
**`ascend_store_connector.py`**

```diff
@@ -1,11 +1,10 @@
 import threading
-from typing import Any, Optional
+from typing import Any

 import torch
 import zmq
 from vllm.config import VllmConfig
-from vllm.distributed.kv_transfer.kv_connector.v1.base import (
-    KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
+from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole
 from vllm.forward_context import ForwardContext
 from vllm.logger import logger
 from vllm.utils.network_utils import make_zmq_socket
@@ -17,31 +16,27 @@ from vllm.v1.request import Request
 from vllm.v1.serial_utils import MsgpackDecoder

 from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.pool_scheduler import (
-    KVPoolScheduler, get_zmq_rpc_path_lookup)
-from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.pool_worker import \
-    KVPoolWorker
+    KVPoolScheduler,
+    get_zmq_rpc_path_lookup,
+)
+from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.pool_worker import KVPoolWorker


 class AscendStoreConnector(KVConnectorBase_V1):
-
-    def __init__(self,
-                 vllm_config: VllmConfig,
-                 role: KVConnectorRole,
-                 kv_cache_config: Optional[KVCacheConfig] = None):
-        super().__init__(vllm_config=vllm_config,
-                         role=role,
-                         kv_cache_config=kv_cache_config)
+    def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole, kv_cache_config: KVCacheConfig | None = None):
+        super().__init__(vllm_config=vllm_config, role=role, kv_cache_config=kv_cache_config)
         self.kv_role = vllm_config.kv_transfer_config.kv_role

-        self.use_layerwise = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
-            "use_layerwise", False)
+        self.use_layerwise = vllm_config.kv_transfer_config.kv_connector_extra_config.get("use_layerwise", False)
         self.consumer_is_to_put = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
-            "consumer_is_to_put", False)
+            "consumer_is_to_put", False
+        )

         connector_name = vllm_config.kv_transfer_config.kv_connector
         if connector_name == "MooncakeConnectorStoreV1":
             logger.warning(
-                "It is recommended to use the AscendStoreConnector, as the MoonCakeStoreConnector will be removed in the future."
+                "It is recommended to use the AscendStoreConnector, "
+                "as the MoonCakeStoreConnector will be removed in the future."
             )

         self.kv_caches: dict[str, torch.Tensor] = {}
@@ -49,8 +44,7 @@ class AscendStoreConnector(KVConnectorBase_V1):
         self.sended_but_unfinished_reqs: set[str] = set()

         if role == KVConnectorRole.SCHEDULER:
-            self.connector_scheduler = KVPoolScheduler(vllm_config,
-                                                       self.use_layerwise)
+            self.connector_scheduler = KVPoolScheduler(vllm_config, self.use_layerwise)
         else:
             self.connector_worker = KVPoolWorker(
                 vllm_config,
@@ -59,27 +53,19 @@ class AscendStoreConnector(KVConnectorBase_V1):

             assert self.connector_worker is not None
             if vllm_config.parallel_config.rank == 0:
-                self.lookup_server = LookupKeyServer(self.connector_worker,
-                                                     vllm_config,
-                                                     self.use_layerwise)
+                self.lookup_server = LookupKeyServer(self.connector_worker, vllm_config, self.use_layerwise)

     ############################################################
     # Scheduler Side Methods
     ############################################################

-    def get_num_new_matched_tokens(
-            self, request: "Request",
-            num_computed_tokens: int) -> tuple[int, bool]:
+    def get_num_new_matched_tokens(self, request: "Request", num_computed_tokens: int) -> tuple[int, bool]:
         assert self.connector_scheduler is not None
-        return self.connector_scheduler.get_num_new_matched_tokens(
-            request, num_computed_tokens)
+        return self.connector_scheduler.get_num_new_matched_tokens(request, num_computed_tokens)

-    def update_state_after_alloc(self, request: "Request",
-                                 blocks: "KVCacheBlocks",
-                                 num_external_tokens: int):
+    def update_state_after_alloc(self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int):
         assert self.connector_scheduler is not None
-        return self.connector_scheduler.update_state_after_alloc(
-            request, blocks, num_external_tokens)
+        return self.connector_scheduler.update_state_after_alloc(request, blocks, num_external_tokens)

     def build_connector_meta(
         self,
@@ -92,7 +78,7 @@ class AscendStoreConnector(KVConnectorBase_V1):
         self,
         request: "Request",
         block_ids: list[int],
-    ) -> tuple[bool, Optional[dict[str, Any]]]:
+    ) -> tuple[bool, dict[str, Any] | None]:
         assert self.connector_scheduler is not None
         return self.connector_scheduler.request_finished(request, block_ids)

@@ -103,8 +89,7 @@ class AscendStoreConnector(KVConnectorBase_V1):
         assert self.connector_worker is not None
         self.connector_worker.register_kv_caches(kv_caches)

-    def start_load_kv(self, forward_context: "ForwardContext",
-                      **kwargs) -> None:
+    def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
         assert self.connector_worker is not None
         self.connector_worker.start_load_kv(self._get_connector_metadata())

@@ -113,8 +98,9 @@ class AscendStoreConnector(KVConnectorBase_V1):
             return
         self.connector_worker.wait_for_layer_load()

-    def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
-                      attn_metadata: "AttentionMetadata", **kwargs) -> None:
+    def save_kv_layer(
+        self, layer_name: str, kv_layer: torch.Tensor, attn_metadata: "AttentionMetadata", **kwargs
+    ) -> None:
         if not self.use_layerwise:
             return

@@ -133,17 +119,16 @@ class AscendStoreConnector(KVConnectorBase_V1):

         self.connector_worker.wait_for_save(self._get_connector_metadata())

-    def get_finished(self,
-                     finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
+    def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
         """Get the finished recving and sending requests."""
         assert self.connector_worker is not None
         done_sending, done_recving = self.connector_worker.get_finished(
-            finished_req_ids, self._get_connector_metadata())
+            finished_req_ids, self._get_connector_metadata()
+        )
         return done_sending, done_recving


 class LookupKeyServer:
-
     def __init__(
         self,
         pool_worker: KVPoolWorker,
@@ -171,8 +156,7 @@ class LookupKeyServer:
         token_len = int.from_bytes(all_frames[0], byteorder="big")
         hash_frames = all_frames[1:]
         hashes_str = self.decoder.decode(hash_frames)
-        result = self.pool_worker.lookup_scheduler(
-            token_len, hashes_str, self.use_layerwise)
+        result = self.pool_worker.lookup_scheduler(token_len, hashes_str, self.use_layerwise)
         response = result.to_bytes(4, "big")
         self.socket.send(response)
```
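For context on the wire format the server decodes above: frame 0 of the request carries the token count as a 4-byte big-endian integer, the remaining frames carry msgpack-encoded block hashes, and the reply is a 4-byte big-endian hit count. A client-side sketch; the endpoint, hash values, and the exact `MsgpackEncoder` framing are assumptions for illustration:

```python
# Hypothetical client for LookupKeyServer, mirroring the framing above:
# [token_len (4B big-endian), *msgpack hash frames] -> 4B big-endian reply.
import zmq
from vllm.v1.serial_utils import MsgpackEncoder

ctx = zmq.Context()
sock = ctx.socket(zmq.REQ)
sock.connect("ipc:///tmp/lookup.ipc")  # placeholder endpoint

token_len = 256
hash_frames = MsgpackEncoder().encode(["hash-0", "hash-1"])  # placeholder hashes

sock.send_multipart([token_len.to_bytes(4, "big"), *hash_frames])
num_hit_tokens = int.from_bytes(sock.recv(), byteorder="big")
print(num_hit_tokens)  # how many prefix tokens the pool already holds
```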
**`backend/backend.py`**

```diff
@@ -4,13 +4,15 @@ from vllm.config import ParallelConfig


 class Backend(ABC):

     @abstractmethod
     def __init__(self, parallel_config: ParallelConfig):
         pass

     @abstractmethod
     def set_device(self):
         pass

     @abstractmethod
     def register_buffer(self, ptrs: list[int], lengths: list[int]):
         pass
@@ -19,11 +21,9 @@ class Backend(ABC):
         pass

     @abstractmethod
-    def put(self, keys: list[str], addrs: list[list[int]],
-            sizes: list[list[int]]):
+    def put(self, keys: list[str], addrs: list[list[int]], sizes: list[list[int]]):
         pass

     @abstractmethod
-    def get(self, keys: list[str], addrs: list[list[int]],
-            sizes: list[list[int]]):
+    def get(self, keys: list[str], addrs: list[list[int]], sizes: list[list[int]]):
         pass
```
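The `Backend` ABC above is the pluggable storage interface: `addrs` and `sizes` are raw device-buffer descriptors, one inner list per key. A minimal sketch of a conforming subclass, assuming an in-process dict as the "store" (a hypothetical toy; the real Mooncake and Memcache backends copy bytes between the given addresses and a remote pool):

```python
from vllm.config import ParallelConfig

from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.backend import Backend


class DictBackend(Backend):  # hypothetical, for illustration only
    def __init__(self, parallel_config: ParallelConfig):
        self.rank = parallel_config.rank
        self.records: dict[str, tuple[list[int], list[int]]] = {}

    def set_device(self):
        pass  # nothing to bind for an in-process store

    def register_buffer(self, ptrs: list[int], lengths: list[int]):
        pass  # no RDMA/transport registration needed

    def exists(self, keys: list[str]) -> list[int]:
        return [1 if key in self.records else 0 for key in keys]

    def put(self, keys: list[str], addrs: list[list[int]], sizes: list[list[int]]):
        # A real backend copies sizes[i] bytes out of each address in addrs[i];
        # the toy version only records the placement metadata.
        for key, addr, size in zip(keys, addrs, sizes):
            self.records[key] = (addr, size)

    def get(self, keys: list[str], addrs: list[list[int]], sizes: list[list[int]]):
        # A real backend copies the stored bytes back into the given addresses.
        for key in keys:
            self.records.get(key)
```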
**`backend/memcache_backend.py`**

```diff
@@ -5,8 +5,7 @@ import torch
 from vllm.config import ParallelConfig
 from vllm.logger import logger

-from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.backend import \
-    Backend
+from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.backend import Backend
 from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type
@@ -18,7 +17,6 @@ class MmcDirect(Enum):


 class MemcacheBackend(Backend):
-
     def __init__(self, parallel_config: ParallelConfig):
         try:
             from memcache_hybrid import DistributedObjectStore  # type: ignore
@@ -26,21 +24,17 @@ class MemcacheBackend(Backend):
             raise ImportError(
                 "Please install memcache by following the instructions at "
                 "https://gitee.com/ascend/memfabric_hybrid "  # noqa: E501
-                "to run vLLM with MemcacheConnector.") from e
+                "to run vLLM with MemcacheConnector."
+            ) from e
         try:
             soc_version = get_ascend_device_type()
             if soc_version in {AscendDeviceType.A2}:
                 import torch
                 from vllm.distributed import get_world_group

                 tmp_tensor = torch.zeros(1, device="npu")
-                output_tensor_list = [
-                    torch.empty_like(tmp_tensor)
-                    for _ in range(torch.distributed.get_world_size())
-                ]
-                torch.distributed.all_gather(
-                    output_tensor_list,
-                    tmp_tensor,
-                    group=get_world_group().device_group)
+                output_tensor_list = [torch.empty_like(tmp_tensor) for _ in range(torch.distributed.get_world_size())]
+                torch.distributed.all_gather(output_tensor_list, tmp_tensor, group=get_world_group().device_group)
             self.rank = parallel_config.rank
             self.store = DistributedObjectStore()
             res = self.store.init(self.rank)
@@ -54,8 +48,7 @@ class MemcacheBackend(Backend):
             logger.error("Configuration loading failed: %s", e)
             raise
         except Exception as exc:
-            logger.error(
-                "An error occurred while loading the configuration: %s", exc)
+            logger.error("An error occurred while loading the configuration: %s", exc)
             raise

     def set_device(self):
@@ -73,22 +66,18 @@ class MemcacheBackend(Backend):
     def exists(self, keys: list[str]) -> list[int]:
         return self.store.batch_is_exist(keys)

-    def get(self, key: list[str], addr: list[list[int]],
-            size: list[list[int]]):
+    def get(self, key: list[str], addr: list[list[int]], size: list[list[int]]):
         try:
-            res = self.store.batch_get_into_layers(key, addr, size,
-                                                   MmcDirect.COPY_G2L.value)
+            res = self.store.batch_get_into_layers(key, addr, size, MmcDirect.COPY_G2L.value)
             for value in res:
                 if value != 0:
                     logger.error(f"Failed to get key {key},res:{res}")
         except Exception as e:
             logger.error(f"Failed to get key {key}. {e}")

-    def put(self, key: list[str], addr: list[list[int]],
-            size: list[list[int]]):
+    def put(self, key: list[str], addr: list[list[int]], size: list[list[int]]):
         try:
-            res = self.store.batch_put_from_layers(key, addr, size,
-                                                   MmcDirect.COPY_L2G.value)
+            res = self.store.batch_put_from_layers(key, addr, size, MmcDirect.COPY_L2G.value)
             for value in res:
                 if value != 0:
                     logger.error(f"Failed to get key {key},res:{res}")
```
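The `key`/`addr`/`size` arguments above are parallel lists: one entry per pool key, and for each key a list of device addresses plus matching byte lengths (the K and V regions of one block). A sketch with made-up values; the key layout follows `PoolKey.to_string()`:

```python
# Illustrative shapes only; addresses and sizes are placeholders.
keys = ["my-model@pcp0@dcp0@head_or_tp_rank:0@pp_rank:0@abc123"]
addrs = [[0x12400000, 0x12500000]]  # [addr_k, addr_v] for this key
sizes = [[65536, 65536]]            # matching byte lengths

# backend.put(keys, addrs, sizes) copies local -> global (MmcDirect.COPY_L2G);
# backend.get(keys, addrs, sizes) copies global -> local (MmcDirect.COPY_G2L).
```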
**`backend/mooncake_backend.py`**

```diff
@@ -2,10 +2,9 @@
 import json
 import os
 import re
-import torch
-
 from dataclasses import dataclass
-from typing import Union
+
+import torch

 # Third Party
 from mooncake.store import ReplicateConfig  # type: ignore
@@ -13,17 +12,14 @@ from vllm.config import ParallelConfig
 from vllm.logger import logger
 from vllm.utils.network_utils import get_ip

-from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.backend import \
-    Backend
-from vllm_ascend.distributed.kv_transfer.utils.mooncake_transfer_engine import \
-    global_te
+from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.backend import Backend
+from vllm_ascend.distributed.kv_transfer.utils.mooncake_transfer_engine import global_te

 DEFAULT_GLOBAL_SEGMENT_SIZE = 3355443200  # 3.125 GiB
 DEFAULT_LOCAL_BUFFER_SIZE = 1073741824  # 1.0 GiB


 class MooncakeBackend(Backend):
-
     def __init__(self, parallel_config: ParallelConfig):
         try:
             from mooncake.store import MooncakeDistributedStore  # type: ignore
@@ -31,23 +27,25 @@ class MooncakeBackend(Backend):
             raise ImportError(
                 "Please install mooncake by following the instructions at "
                 "https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md "  # noqa: E501
-                "to run vLLM with MooncakeConnector.") from e
+                "to run vLLM with MooncakeConnector."
+            ) from e
         self.config = MooncakeStoreConfig.load_from_env()
         self.store = MooncakeDistributedStore()
         self.rank = parallel_config.rank
         if self.config.protocol == "ascend":
             local_hostname = get_ip()
-            transfer_engine = global_te.get_transfer_engine(local_hostname,
-                                                            device_name=None)
-            self.local_seg = local_hostname + ":" + str(
-                transfer_engine.get_rpc_port())
-            ret = self.store.setup(self.local_seg, self.config.metadata_server,
-                                   self.config.global_segment_size,
-                                   self.config.local_buffer_size,
-                                   self.config.protocol,
-                                   self.config.device_name,
-                                   self.config.master_server_address,
-                                   transfer_engine.get_engine())
+            transfer_engine = global_te.get_transfer_engine(local_hostname, device_name=None)
+            self.local_seg = local_hostname + ":" + str(transfer_engine.get_rpc_port())
+            ret = self.store.setup(
+                self.local_seg,
+                self.config.metadata_server,
+                self.config.global_segment_size,
+                self.config.local_buffer_size,
+                self.config.protocol,
+                self.config.device_name,
+                self.config.master_server_address,
+                transfer_engine.get_engine(),
+            )
             if ret != 0:
                 msg = "Initialize mooncake failed."
                 logger.error(msg)
@@ -63,25 +61,21 @@ class MooncakeBackend(Backend):
     def exists(self, keys: list[str]) -> list[int]:
         return self.store.batch_is_exist(keys)

-    def put(self, keys: list[str], addrs: list[list[int]],
-            sizes: list[list[int]]):
+    def put(self, keys: list[str], addrs: list[list[int]], sizes: list[list[int]]):
         try:
             config = ReplicateConfig()
             config.preferred_segment = self.local_seg
             config.prefer_alloc_in_same_node = True
-            res = self.store.batch_put_from_multi_buffers(
-                keys, addrs, sizes, config)
+            res = self.store.batch_put_from_multi_buffers(keys, addrs, sizes, config)
             for value in res:
                 if value < 0:
                     logger.error(f"Failed to put key {keys},res:{res}")
         except Exception as e:
             logger.error(f"Failed to put key {keys},error:{e}")

-    def get(self, keys: list[str], addrs: list[list[int]],
-            sizes: list[list[int]]):
+    def get(self, keys: list[str], addrs: list[list[int]], sizes: list[list[int]]):
         try:
-            res = self.store.batch_get_into_multi_buffers(
-                keys, addrs, sizes, True)
+            res = self.store.batch_get_into_multi_buffers(keys, addrs, sizes, True)
             for value in res:
                 if value < 0:
                     logger.error(f"Failed to get key {keys}, res:{res}")
@@ -92,7 +86,7 @@ class MooncakeBackend(Backend):
 @dataclass
 class MooncakeStoreConfig:
     metadata_server: str
-    global_segment_size: Union[int, str]
+    global_segment_size: int | str
     local_buffer_size: int
     protocol: str
     device_name: str
@@ -105,33 +99,32 @@ class MooncakeStoreConfig:
         return MooncakeStoreConfig(
             metadata_server=config.get("metadata_server"),
             global_segment_size=_parse_global_segment_size(
-                config.get("global_segment_size",
-                           DEFAULT_GLOBAL_SEGMENT_SIZE)),
-            local_buffer_size=_parse_global_segment_size(
-                config.get("local_buffer_size", DEFAULT_LOCAL_BUFFER_SIZE)),
+                config.get("global_segment_size", DEFAULT_GLOBAL_SEGMENT_SIZE)
+            ),
+            local_buffer_size=_parse_global_segment_size(config.get("local_buffer_size", DEFAULT_LOCAL_BUFFER_SIZE)),
             protocol=config.get("protocol", "ascend"),
             device_name=config.get("device_name", ""),
-            master_server_address=config.get("master_server_address"))
+            master_server_address=config.get("master_server_address"),
+        )

     @staticmethod
     def load_from_env() -> "MooncakeStoreConfig":
         config_path = os.getenv("MOONCAKE_CONFIG_PATH")
         if not config_path:
-            raise ValueError(
-                "The environment variable 'MOONCAKE_CONFIG_PATH' is not set.")
+            raise ValueError("The environment variable 'MOONCAKE_CONFIG_PATH' is not set.")
         return MooncakeStoreConfig.from_file(config_path)


 def _parse_global_segment_size(value) -> int:
     """
     Parse storage size strings with support for units: GB, MB, KB, B

     Args:
         value: Input value (int, str, or other convertible types)

     Returns:
         int: Size in bytes

     Raises:
         ValueError: For invalid format, missing number, or negative values
         TypeError: For unsupported input types
@@ -143,54 +136,50 @@ def _parse_global_segment_size(value) -> int:
         try:
             return int(value)
         except (TypeError, ValueError) as e:
-            raise TypeError(
-                f"Unsupported type for global_segment_size: {type(value)}"
-            ) from e
+            raise TypeError(f"Unsupported type for global_segment_size: {type(value)}") from e

     cleaned_input = value.strip().lower()
     if not cleaned_input:
         raise ValueError("global segment size cannot be empty.")

     UNIT_MULTIPLIERS = {
-        'gb': 1024**3,  # 1 GB = 1024^3 bytes
-        'mb': 1024**2,  # 1 MB = 1024^2 bytes
-        'kb': 1024,  # 1 KB = 1024 bytes
-        'b': 1  # 1 B = 1 byte
+        "gb": 1024**3,  # 1 GB = 1024^3 bytes
+        "mb": 1024**2,  # 1 MB = 1024^2 bytes
+        "kb": 1024,  # 1 KB = 1024 bytes
+        "b": 1,  # 1 B = 1 byte
     }
-    pattern = r'^\s*([\d.]+)\s*(gb|mb|kb|b)?\s*$'
+    pattern = r"^\s*([\d.]+)\s*(gb|mb|kb|b)?\s*$"
     match = re.match(pattern, cleaned_input)

     if not match:
         raise ValueError(f"Invalid format: '{value}'")

     number_str = match.group(1)
-    unit = match.group(2) or 'b'
+    unit = match.group(2) or "b"

     multiplier = UNIT_MULTIPLIERS[unit]
     return _convert_to_bytes(number_str, multiplier, value)


-def _convert_to_bytes(number_str: str, multiplier: int,
-                      original_input: str) -> int:
+def _convert_to_bytes(number_str: str, multiplier: int, original_input: str) -> int:
     """
     Convert numeric string to byte count

     Args:
         number_str: Numeric portion of input
         multiplier: Unit conversion factor
         original_input: Original input string (for error messages)

     Returns:
         int: Byte count

     Raises:
         ValueError: For invalid numbers or negative results
     """
     try:
         numeric_value = float(number_str)
     except ValueError:
-        raise ValueError(
-            f"Invalid numeric value '{number_str}' in: '{original_input}'")
+        raise ValueError(f"Invalid numeric value '{number_str}' in: '{original_input}'")
     # Calculate byte count
     try:
         byte_count = int(numeric_value * multiplier)
```
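Worked examples of `_parse_global_segment_size` based on the unit table above ("3.125gb" is exactly the 3355443200-byte default); the bare-int passthrough assumes the non-string branch shown at the start of the function:

```python
assert _parse_global_segment_size("3.125gb") == 3355443200  # 3.125 * 1024**3
assert _parse_global_segment_size("1gb") == 1073741824      # 1024**3
assert _parse_global_segment_size("512 mb") == 536870912    # whitespace tolerated
assert _parse_global_segment_size("1024") == 1024           # no unit -> bytes
assert _parse_global_segment_size(4096) == 4096             # int passes through
```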
**`config_data.py`**

```diff
@@ -1,16 +1,16 @@
+from collections.abc import Iterable
 from dataclasses import dataclass
-from typing import Iterable, List, Optional, Tuple, Union
+from typing import Optional

 import torch
-from vllm.distributed.kv_transfer.kv_connector.v1.base import \
-    KVConnectorMetadata
+from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata
 from vllm.logger import logger
 from vllm.utils.math_utils import cdiv
 from vllm.v1.core.kv_cache_utils import BlockHash
 from vllm.v1.core.sched.output import NewRequestData


-#Parameters related to the key
+# Parameters related to the key
 @dataclass
 class KeyMetadata:
     """name of the LLM model"""
@@ -32,23 +32,26 @@ class PoolKey:
     chunk_hash: str

     def __hash__(self):
-        return hash((
-            self.key_metadata.model_name,
-            self.key_metadata.head_or_tp_rank,
-            self.key_metadata.pcp_rank,
-            self.key_metadata.dcp_rank,
-            self.key_metadata.pp_rank,
-            self.chunk_hash,
-        ))
+        return hash(
+            (
+                self.key_metadata.model_name,
+                self.key_metadata.head_or_tp_rank,
+                self.key_metadata.pcp_rank,
+                self.key_metadata.dcp_rank,
+                self.key_metadata.pp_rank,
+                self.chunk_hash,
+            )
+        )

     def to_string(self):
         return (
             f"{self.key_metadata.model_name}"
             f"@pcp{self.key_metadata.pcp_rank}@dcp{self.key_metadata.dcp_rank}"
             f"@head_or_tp_rank:{self.key_metadata.head_or_tp_rank}"
-            f"@pp_rank:{self.key_metadata.pp_rank}@{self.chunk_hash}")
+            f"@pp_rank:{self.key_metadata.pp_rank}@{self.chunk_hash}"
+        )

-    def split_layers(self, num_layers: int) -> List["LayerPoolKey"]:
+    def split_layers(self, num_layers: int) -> list["LayerPoolKey"]:
         """Split the key into multiple keys for each layer"""
         keys = []
         for layer_id in range(num_layers):
@@ -57,7 +60,8 @@ class PoolKey:
                     self.key_metadata,
                     self.chunk_hash,
                     layer_id,
-                ))
+                )
+            )
         return keys


@@ -68,14 +72,16 @@ class LayerPoolKey(PoolKey):
     layer_id: int

     def __hash__(self):
-        return hash((
-            self.key_metadata.model_name,
-            self.key_metadata.head_or_tp_rank,
-            self.key_metadata.pcp_rank,
-            self.key_metadata.dcp_rank,
-            self.chunk_hash,
-            self.layer_id,
-        ))
+        return hash(
+            (
+                self.key_metadata.model_name,
+                self.key_metadata.head_or_tp_rank,
+                self.key_metadata.pcp_rank,
+                self.key_metadata.dcp_rank,
+                self.chunk_hash,
+                self.layer_id,
+            )
+        )

     def to_string(self):
         return (
@@ -85,10 +91,8 @@ class LayerPoolKey(PoolKey):
         )


-class ChunkedTokenDatabase():
-
-    def __init__(self, metadata: KeyMetadata, block_size: int, use_mla: bool,
-                 partitions: Optional[List[int]]):
+class ChunkedTokenDatabase:
+    def __init__(self, metadata: KeyMetadata, block_size: int, use_mla: bool, partitions: list[int] | None):
         self.metadata = metadata
         self.block_size = block_size
         self.use_mla = use_mla
@@ -96,9 +100,7 @@ class ChunkedTokenDatabase:
         self.block_len: list[int] = []
         self.partitions = partitions

-    def _make_key_by_hash(self,
-                          chunk_hash: str,
-                          layer_id: Optional[int] = None):
+    def _make_key_by_hash(self, chunk_hash: str, layer_id: int | None = None):
         assert self.metadata is not None
         return PoolKey(
             self.metadata,
@@ -116,8 +118,7 @@ class ChunkedTokenDatabase:
         size_list = []
         block_id = block_ids[start // self.block_size]
         for index, base_addr in enumerate(self.kv_caches_base_addr):
-            block_len = (self.block_len[index % 2]
-                         if self.use_mla else self.block_len[0])
+            block_len = self.block_len[index % 2] if self.use_mla else self.block_len[0]

             addr = base_addr + block_id * block_len
             length = int(block_len / self.block_size * (end - start))
@@ -125,22 +126,17 @@ class ChunkedTokenDatabase:
             size_list.append(length)
         return addr_list, size_list, block_id

-    def prepare_value_layer(self, start: int, end: int, block_ids: list[int],
-                            layer_id: int):
+    def prepare_value_layer(self, start: int, end: int, block_ids: list[int], layer_id: int):
         block_id = block_ids[start // self.block_size]
         if self.use_mla:
-            addr_k = self.kv_caches_base_addr[layer_id *
-                                              2] + block_id * self.block_len[0]
-            addr_v = self.kv_caches_base_addr[layer_id * 2 +
-                                              1] + block_id * self.block_len[1]
+            addr_k = self.kv_caches_base_addr[layer_id * 2] + block_id * self.block_len[0]
+            addr_v = self.kv_caches_base_addr[layer_id * 2 + 1] + block_id * self.block_len[1]
             length_k = int(self.block_len[0] / self.block_size * (end - start))
             length_v = int(self.block_len[1] / self.block_size * (end - start))
             size_list = [length_k, length_v]
         else:
-            addr_k = self.kv_caches_base_addr[layer_id *
-                                              2] + block_id * self.block_len[0]
-            addr_v = self.kv_caches_base_addr[layer_id * 2 +
-                                              1] + block_id * self.block_len[0]
+            addr_k = self.kv_caches_base_addr[layer_id * 2] + block_id * self.block_len[0]
+            addr_v = self.kv_caches_base_addr[layer_id * 2 + 1] + block_id * self.block_len[0]
             length = int(self.block_len[0] / self.block_size * (end - start))
             size_list = [length, length]
         addr_list = [addr_k, addr_v]
@@ -149,9 +145,9 @@ class ChunkedTokenDatabase:
     def process_tokens(
         self,
         token_len: int,
-        block_hashes: Union[list[BlockHash], list[str]],
+        block_hashes: list[BlockHash] | list[str],
         mask_num: int = 0,
-    ) -> Iterable[Tuple[int, int, PoolKey]]:
+    ) -> Iterable[tuple[int, int, PoolKey]]:
         """Process the tokens and return the corresponding cache engine keys.

         :param Union[torch.Tensor, List[int]] tokens: The tokens to process.
@@ -202,10 +198,10 @@ class ChunkedTokenDatabase:
             start = 0
             for j, part in enumerate(self.partitions):
                 # part * 2 because addr and size contain both k and v
-                end = len(addr_list) if j == len(
-                    self.partitions) - 1 else start + part * 2
+                end = len(addr_list) if j == len(self.partitions) - 1 else start + part * 2
                 new_str = key[i].replace(  # type: ignore[attr-defined]
-                    "@pp_rank:0", f"@pp_rank:{j}", 1)
+                    "@pp_rank:0", f"@pp_rank:{j}", 1
+                )
                 new_key.append(new_str)
                 new_addr.append(addr_list[start:end])
                 new_size.append(size_list[start:end])
@@ -213,7 +209,7 @@ class ChunkedTokenDatabase:
         return new_key, new_addr, new_size


-#Parameters related to the connector metadata
+# Parameters related to the connector metadata
 @dataclass
 class LoadSpec:
     # Number of tokens cached in vLLM
@@ -273,7 +269,7 @@ class RequestTracker:

     def update(
         self,
-        new_block_ids: Union[tuple[list[int], ...], list[int]],
+        new_block_ids: tuple[list[int], ...] | list[int],
     ) -> None:
         """Update the request tracker when a running request is
         scheduled again
@@ -286,8 +282,7 @@ class RequestTracker:
         elif isinstance(new_block_ids, list):
             pass
         else:
-            raise ValueError(
-                f"Unsupported new_block_ids type {type(new_block_ids)}")
+            raise ValueError(f"Unsupported new_block_ids type {type(new_block_ids)}")
         self.allocated_block_ids.extend(new_block_ids)


@@ -302,22 +297,22 @@ class ReqMeta:

     block_hashes: list[BlockHash]

-    can_save: Optional[bool] = None
+    can_save: bool | None = None
     # load_spec
-    load_spec: Optional[LoadSpec] = None
+    load_spec: LoadSpec | None = None

-    is_last_chunk: Optional[bool] = None
+    is_last_chunk: bool | None = None

-    current_event: Optional[torch.npu.Event] = None
+    current_event: torch.npu.Event | None = None

     @staticmethod
     def from_request_tracker(
         tracker: RequestTracker,
         block_size: int,
-        load_spec: Optional[LoadSpec] = None,
-        skip_save: Optional[bool] = False,
-        block_hashes: list[BlockHash] = [],
-        is_last_chunk: Optional[bool] = None,
+        load_spec: LoadSpec | None = None,
+        skip_save: bool | None = False,
+        block_hashes: list[BlockHash] | None = None,
+        is_last_chunk: bool | None = None,
         discard_partial_chunks: bool = True,
     ) -> Optional["ReqMeta"]:
         """Create the request metadata from a request tracker.
@@ -333,17 +328,17 @@ class ReqMeta:
             the request metadata if we need to perform load/save
             operations, None otherwise.
         """
+        if block_hashes is None:
+            block_hashes = []
         input_token_len = tracker.token_len

         # For save operation: do not save if the following condition is met
         # 1. has already been saved before (num_saved_tokens > 0)
         # 2. number of unsaved tokens is not reached the chunk boundary
-        chunk_boundary = (cdiv(tracker.num_saved_tokens + 1, block_size) *
-                          block_size if discard_partial_chunks else 0)
+        chunk_boundary = cdiv(tracker.num_saved_tokens + 1, block_size) * block_size if discard_partial_chunks else 0
         # Calculate number of tokens to save based on discard_partial_chunks
         # setting
-        num_tokens_to_save = ((input_token_len // block_size * block_size)
-                              if discard_partial_chunks else input_token_len)
+        num_tokens_to_save = (input_token_len // block_size * block_size) if discard_partial_chunks else input_token_len

         skip_save = skip_save or num_tokens_to_save < chunk_boundary
         if skip_save and load_spec is None:
@@ -363,9 +358,7 @@ class ReqMeta:
         else:
             # Do not load if not in `can_load` state
             load_spec = None
-        logger.debug(
-            f"request:{tracker.req_id}, meta save spec:{not skip_save}, meta load spec:{load_spec}"
-        )
+        logger.debug(f"request:{tracker.req_id}, meta save spec:{not skip_save}, meta load spec:{load_spec}")
         return ReqMeta(
             req_id=tracker.req_id,
             token_len_chunk=num_tokens_to_save,
@@ -378,7 +371,6 @@ class ReqMeta:


 class AscendConnectorMetadata(KVConnectorMetadata):
-
     def __init__(self, unfinished_request_ids, preempted_req_ids):
         self.requests = []
         self.unfinished_request_ids = unfinished_request_ids
@@ -396,10 +388,10 @@ class AscendConnectorMetadata(KVConnectorMetadata):
 @dataclass
 class LasyerMultiBlockReqMeta:
     req_id: str
-    keys: List[LayerPoolKey]
-    starts: List[int]
+    keys: list[LayerPoolKey]
+    starts: list[int]
     ends: list[int]
     block_ids: list[int]
     layer_id: int
-    is_last_chunk: Optional[bool] = True
-    current_event: Optional[torch.npu.Event] = None
+    is_last_chunk: bool | None = True
+    current_event: torch.npu.Event | None = None
```
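For reference, the key layout produced by `PoolKey.to_string()` above. A sketch; the field values and the `KeyMetadata` field order are illustrative assumptions, but the format string is the one defined in the class:

```python
meta = KeyMetadata(
    model_name="my-model",  # placeholder values
    head_or_tp_rank=0,
    pcp_rank=0,
    dcp_rank=0,
    pp_rank=0,
)
key = PoolKey(meta, "abc123")
print(key.to_string())
# my-model@pcp0@dcp0@head_or_tp_rank:0@pp_rank:0@abc123
```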
**`kv_transfer.py`**

```diff
@@ -7,8 +7,7 @@ from typing import Any
 import torch
 from vllm.logger import logger

-from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.backend import \
-    Backend
+from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.backend import Backend

 # isort: off
 from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.config_data import (
@@ -20,10 +19,16 @@ from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.config_data import


 class KVTransferThread(threading.Thread):
-
-    def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
-                 block_size: int, tp_rank: int, dcp_size: int,
-                 ready_event: threading.Event, name: str):
+    def __init__(
+        self,
+        m_store: Backend,
+        token_database: ChunkedTokenDatabase,
+        block_size: int,
+        tp_rank: int,
+        dcp_size: int,
+        ready_event: threading.Event,
+        name: str,
+    ):
         super().__init__(daemon=True, name=name)
         self.m_store = m_store
         self.ready_event = ready_event
@@ -39,7 +44,7 @@ class KVTransferThread(threading.Thread):

     def add_request(
         self,
-        request: ReqMeta,
+        request: ReqMeta | LasyerMultiBlockReqMeta,
     ) -> torch.Tensor:
         self.request_queue.put(request)

@@ -98,17 +103,20 @@ class KVTransferThread(threading.Thread):


 class KVCacheStoreSendingThread(KVTransferThread):
-
-    def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
-                 block_size: int, tp_rank: int, dcp_size: int, put_step: int,
-                 kv_role: str, ready_event: threading.Event):
-        super().__init__(m_store,
-                         token_database,
-                         block_size,
-                         tp_rank,
-                         dcp_size,
-                         ready_event,
-                         name="KVCacheSendingThread")
+    def __init__(
+        self,
+        m_store: Backend,
+        token_database: ChunkedTokenDatabase,
+        block_size: int,
+        tp_rank: int,
+        dcp_size: int,
+        put_step: int,
+        kv_role: str,
+        ready_event: threading.Event,
+    ):
+        super().__init__(
+            m_store, token_database, block_size, tp_rank, dcp_size, ready_event, name="KVCacheSendingThread"
+        )
         self.put_step = put_step
         self.kv_role = kv_role
         self.stored_requests = defaultdict[str, int](int)
@@ -139,16 +147,15 @@ class KVCacheStoreSendingThread(KVTransferThread):
             self.request_queue.task_done()
             return

-        for start, end, key in self.token_database.process_tokens(
-                token_len, req_meta.block_hashes):
+        for start, end, key in self.token_database.process_tokens(token_len, req_meta.block_hashes):
             starts.append(start)
             ends.append(end)
             keys.append(key.to_string())

         if not self.dcp_size > 1:
-            starts = starts[self.tp_rank % self.put_step::self.put_step]
-            ends = ends[self.tp_rank % self.put_step::self.put_step]
-            keys = keys[self.tp_rank % self.put_step::self.put_step]
+            starts = starts[self.tp_rank % self.put_step :: self.put_step]
+            ends = ends[self.tp_rank % self.put_step :: self.put_step]
+            keys = keys[self.tp_rank % self.put_step :: self.put_step]

         if not keys:
             self.dec_stored_request(req_id)
@@ -165,8 +172,7 @@ class KVCacheStoreSendingThread(KVTransferThread):
         keys = keys[skip_block_num:]

         logger.info(
-            "Storing KV cache for %d out of %d blocks "
-            "(skip_block_num=%d) for request %s",
+            "Storing KV cache for %d out of %d blocks (skip_block_num=%d) for request %s",
             len(keys),
             token_len // self.block_size,
             skip_block_num,
@@ -183,14 +189,12 @@ class KVCacheStoreSendingThread(KVTransferThread):
         addrs = []
         sizes = []
         for index, start in enumerate(starts):
-            addr, size, _ = self.token_database.prepare_value(
-                start, ends[index], block_ids)
+            addr, size, _ = self.token_database.prepare_value(start, ends[index], block_ids)
             addrs.append(addr)
             sizes.append(size)

         if self.kv_role == "kv_consumer":
-            keys, addrs, sizes = self.token_database.decode_adaptor_prefill_pp(
-                keys, addrs, sizes)
+            keys, addrs, sizes = self.token_database.decode_adaptor_prefill_pp(keys, addrs, sizes)

         if current_event is not None:
             current_event.synchronize()
@@ -201,69 +205,69 @@ class KVCacheStoreSendingThread(KVTransferThread):


 class KVCacheStoreRecvingThread(KVTransferThread):
-
-    def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
-                 block_size: int, tp_rank: int, dcp_size: int,
-                 ready_event: threading.Event):
-        super().__init__(m_store,
-                         token_database,
-                         block_size,
-                         tp_rank,
-                         dcp_size,
-                         ready_event,
-                         name="KVCacheStoreRecvingThread")
+    def __init__(
+        self,
+        m_store: Backend,
+        token_database: ChunkedTokenDatabase,
+        block_size: int,
+        tp_rank: int,
+        dcp_size: int,
+        ready_event: threading.Event,
+    ):
+        super().__init__(
+            m_store, token_database, block_size, tp_rank, dcp_size, ready_event, name="KVCacheStoreRecvingThread"
+        )

     def _handle_request(self, req_meta: ReqMeta):
         token_len = req_meta.load_spec.token_len  # type: ignore[union-attr]
         req_id = req_meta.req_id
         mask_num = (
             req_meta.load_spec.vllm_cached_tokens  # type: ignore[union-attr]
-            // self.block_size * self.block_size)
+            // self.block_size
+            * self.block_size
+        )
         addr_list = []
         size_list = []
         key_list = []
-        for start, end, key in self.token_database.process_tokens(
-                token_len, req_meta.block_hashes, mask_num):
-            addr, size, _ = self.token_database.prepare_value(
-                start, end, req_meta.block_ids)
+        for start, end, key in self.token_database.process_tokens(token_len, req_meta.block_hashes, mask_num):
+            addr, size, _ = self.token_database.prepare_value(start, end, req_meta.block_ids)
             key_list.append(key.to_string())
             addr_list.append(addr)
             size_list.append(size)
-        key_list_c = key_list[self.tp_rank %
-                              len(key_list):] + key_list[:self.tp_rank %
-                                                         len(key_list)]
-        addr_list_c = addr_list[self.tp_rank %
-                                len(addr_list):] + addr_list[:self.tp_rank %
-                                                             len(addr_list)]
-        size_list_c = size_list[self.tp_rank %
-                                len(size_list):] + size_list[:self.tp_rank %
-                                                             len(size_list)]
+        key_list_c = key_list[self.tp_rank % len(key_list) :] + key_list[: self.tp_rank % len(key_list)]
+        addr_list_c = addr_list[self.tp_rank % len(addr_list) :] + addr_list[: self.tp_rank % len(addr_list)]
+        size_list_c = size_list[self.tp_rank % len(size_list) :] + size_list[: self.tp_rank % len(size_list)]
         self.m_store.get(key_list_c, addr_list_c, size_list_c)
         self.set_finished_request(req_id)
         self.request_queue.task_done()


 class KVCacheStoreLayerSendingThread(KVTransferThread):
-
-    def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
-                 block_size: int, tp_rank: int, dcp_size: int, put_step: int,
-                 ready_event: threading.Event, num_layers: int):
-        super().__init__(m_store,
-                         token_database,
-                         block_size,
-                         tp_rank,
-                         dcp_size,
-                         ready_event,
-                         name="KVCacheStoreLayerSendingThread")
+    def __init__(
+        self,
+        m_store: Backend,
+        token_database: ChunkedTokenDatabase,
+        block_size: int,
+        tp_rank: int,
+        dcp_size: int,
+        put_step: int,
+        ready_event: threading.Event,
+        num_layers: int,
+    ):
+        super().__init__(
+            m_store, token_database, block_size, tp_rank, dcp_size, ready_event, name="KVCacheStoreLayerSendingThread"
+        )
         self.final_layer_id = num_layers - 1
         self.put_step = put_step

     def add_request(  # type: ignore[override]
-            self, req_meta: ReqMeta) -> torch.Tensor:
+        self, req_meta: ReqMeta
+    ) -> torch.Tensor:
         self.request_queue.put(req_meta)

     def _handle_request(  # type: ignore[override]
-            self, req_meta: LasyerMultiBlockReqMeta):
+        self, req_meta: LasyerMultiBlockReqMeta
+    ):
         starts = req_meta.starts
         ends = req_meta.ends
         keys = req_meta.keys
@@ -272,9 +276,9 @@ class KVCacheStoreLayerSendingThread(KVTransferThread):
         total_block = len(keys)
         is_last_chunk = req_meta.is_last_chunk
         if not self.dcp_size > 1:
-            starts = starts[self.tp_rank % self.put_step::self.put_step]
-            ends = ends[self.tp_rank % self.put_step::self.put_step]
-            keys = keys[self.tp_rank % self.put_step::self.put_step]
+            starts = starts[self.tp_rank % self.put_step :: self.put_step]
+            ends = ends[self.tp_rank % self.put_step :: self.put_step]
+            keys = keys[self.tp_rank % self.put_step :: self.put_step]

         if not keys:
             if is_last_chunk:
@@ -300,7 +304,8 @@ class KVCacheStoreLayerSendingThread(KVTransferThread):
         size_list = []
         for index, key in enumerate(key_list):
             addr, size = self.token_database.prepare_value_layer(
-                starts[index], ends[index], req_meta.block_ids, layer_id)
+                starts[index], ends[index], req_meta.block_ids, layer_id
+            )
             addr_list.append(addr)
             size_list.append(size)

@@ -313,8 +318,7 @@ class KVCacheStoreLayerSendingThread(KVTransferThread):
         self.request_queue.task_done()

         logger.info(
-            "Storing KV cache for %d out of %d blocks "
-            "(skip_block_num=%d) for request %s",
+            "Storing KV cache for %d out of %d blocks (skip_block_num=%d) for request %s",
             len(keys),
             total_block,
             skip_block_num,
@@ -323,44 +327,42 @@ class KVCacheStoreLayerSendingThread(KVTransferThread):


 class KVCacheStoreLayerRecvingThread(KVTransferThread):
-
-    def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
-                 block_size: int, tp_rank: int, dcp_size: int,
-                 ready_event: threading.Event, get_event: threading.Event):
-        super().__init__(m_store,
-                         token_database,
-                         block_size,
-                         tp_rank,
-                         dcp_size,
-                         ready_event,
-                         name="KVCacheStoreLayerRecvingThread")
+    def __init__(
+        self,
+        m_store: Backend,
+        token_database: ChunkedTokenDatabase,
+        block_size: int,
+        tp_rank: int,
+        dcp_size: int,
+        ready_event: threading.Event,
+        get_event: threading.Event,
+    ):
+        super().__init__(
+            m_store, token_database, block_size, tp_rank, dcp_size, ready_event, name="KVCacheStoreLayerRecvingThread"
+        )
         self.get_event = get_event

     def add_request(  # type: ignore[override]
-            self, req_meta: LasyerMultiBlockReqMeta) -> torch.Tensor:
+        self, req_meta: LasyerMultiBlockReqMeta
+    ) -> torch.Tensor:
         self.request_queue.put(req_meta)

     def _handle_request(  # type: ignore[override]
-            self, req_meta: LasyerMultiBlockReqMeta):
+        self, req_meta: LasyerMultiBlockReqMeta
+    ):
         addr_list = []
         size_list = []
         key_list = []
         for index, key in enumerate(req_meta.keys):
             addr, size = self.token_database.prepare_value_layer(
-                req_meta.starts[index], req_meta.ends[index],
-                req_meta.block_ids, req_meta.layer_id)
+                req_meta.starts[index], req_meta.ends[index], req_meta.block_ids, req_meta.layer_id
+            )
             key_list.append(key.to_string())
             addr_list.append(addr)
             size_list.append(size)
-        key_list_c = key_list[self.tp_rank %
-                              len(key_list):] + key_list[:self.tp_rank %
-                                                         len(key_list)]
-        addr_list_c = addr_list[self.tp_rank %
-                                len(addr_list):] + addr_list[:self.tp_rank %
-                                                             len(addr_list)]
-        size_list_c = size_list[self.tp_rank %
-                                len(size_list):] + size_list[:self.tp_rank %
-                                                             len(size_list)]
+        key_list_c = key_list[self.tp_rank % len(key_list) :] + key_list[: self.tp_rank % len(key_list)]
+        addr_list_c = addr_list[self.tp_rank % len(addr_list) :] + addr_list[: self.tp_rank % len(addr_list)]
+        size_list_c = size_list[self.tp_rank % len(size_list) :] + size_list[: self.tp_rank % len(size_list)]
         self.m_store.get(key_list_c, addr_list_c, size_list_c)

         self.request_queue.task_done()
```
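The repeated `lst[rank % n:] + lst[:rank % n]` pattern above rotates the same key list by each rank's offset, so concurrent `get` calls from different tensor-parallel ranks start on different entries instead of all hitting the first key. A minimal sketch of the behavior in isolation:

```python
def rotate(items: list, tp_rank: int) -> list:
    """Start each rank at a different offset in the shared key list."""
    offset = tp_rank % len(items)
    return items[offset:] + items[:offset]

keys = ["k0", "k1", "k2", "k3"]
print(rotate(keys, 0))  # ['k0', 'k1', 'k2', 'k3']
print(rotate(keys, 1))  # ['k1', 'k2', 'k3', 'k0']
print(rotate(keys, 2))  # ['k2', 'k3', 'k0', 'k1']
```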
@@ -1,10 +1,9 @@
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
import vllm.envs as envs
|
||||
import zmq
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import \
|
||||
KVConnectorMetadata
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata
|
||||
from vllm.logger import logger
|
||||
from vllm.utils.network_utils import make_zmq_socket
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
@@ -14,27 +13,29 @@ from vllm.v1.request import Request
|
||||
from vllm.v1.serial_utils import MsgpackEncoder
|
||||
|
||||
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.config_data import (
|
||||
AscendConnectorMetadata, LoadSpec, ReqMeta, RequestTracker)
|
||||
AscendConnectorMetadata,
|
||||
LoadSpec,
|
||||
ReqMeta,
|
||||
RequestTracker,
|
||||
)
|
||||
|
||||
|
||||
class KVPoolScheduler:
|
||||
|
||||
def __init__(self, vllm_config: "VllmConfig", use_layerwise):
|
||||
self.use_layerwise = use_layerwise
|
||||
self.kv_role = vllm_config.kv_transfer_config.kv_role
|
||||
self.consumer_is_to_load = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
|
||||
"consumer_is_to_load", False)
|
||||
"consumer_is_to_load", False
|
||||
)
|
||||
self.consumer_is_to_put = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
|
||||
"consumer_is_to_put", False)
|
||||
self.load_async = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
|
||||
"load_async", False)
|
||||
"consumer_is_to_put", False
|
||||
)
|
||||
self.load_async = vllm_config.kv_transfer_config.kv_connector_extra_config.get("load_async", False)
|
||||
self.client = LookupKeyClient(vllm_config)
|
||||
# request_id -> (vllm cached tokes, kvpool cached tokens)
|
||||
self.load_specs: dict[str, LoadSpec] = {}
|
||||
self.pcp_size = getattr(vllm_config.parallel_config,
|
||||
"prefill_context_parallel_size", 1)
|
||||
self.dcp_size = getattr(vllm_config.parallel_config,
|
||||
"decode_context_parallel_size", 1)
|
||||
self.pcp_size = getattr(vllm_config.parallel_config, "prefill_context_parallel_size", 1)
|
||||
self.dcp_size = getattr(vllm_config.parallel_config, "decode_context_parallel_size", 1)
|
||||
|
||||
self._block_size = vllm_config.cache_config.block_size
|
||||
if self.pcp_size > 1:
|
||||
@@ -45,9 +46,9 @@ class KVPoolScheduler:
|
||||
self._request_trackers: dict[str, RequestTracker] = {}
|
||||
self._preempted_req_ids: set[str] = set()
|
||||
# Whether to discard partial chunks
|
||||
self._discard_partial_chunks = (
|
||||
vllm_config.kv_transfer_config.get_from_extra_config(
|
||||
"discard_partial_chunks", True))
|
||||
self._discard_partial_chunks = vllm_config.kv_transfer_config.get_from_extra_config(
|
||||
"discard_partial_chunks", True
|
||||
)
|
||||
self._unfinished_requests: dict[str, tuple[Request, list[int]]] = {}
|
||||
self._unfinished_request_ids: set[str] = set()
|
||||
|
||||
@@ -72,13 +73,11 @@ class KVPoolScheduler:
|
||||
return 0, False
|
||||
|
||||
if self._discard_partial_chunks:
|
||||
token_len = len(request.prompt_token_ids
|
||||
) // self._block_size * self._block_size
|
||||
token_len = len(request.prompt_token_ids) // self._block_size * self._block_size
|
||||
else:
|
||||
token_len = len(request.prompt_token_ids)
|
||||
|
||||
num_external_hit_tokens = self.client.lookup(token_len,
|
||||
request.block_hashes)
|
||||
num_external_hit_tokens = self.client.lookup(token_len, request.block_hashes)
|
||||
|
||||
if num_external_hit_tokens == request.num_tokens:
|
||||
num_external_hit_tokens -= 1
|
||||
@@ -107,9 +106,7 @@ class KVPoolScheduler:
|
||||
|
||||
return need_to_allocate, self.load_async and not self.use_layerwise
|
||||
|
||||
def update_state_after_alloc(self, request: "Request",
|
||||
blocks: "KVCacheBlocks",
|
||||
num_external_tokens: int):
|
||||
def update_state_after_alloc(self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int):
|
||||
"""
|
||||
Update KVConnector state after temporary buffer alloc.
|
||||
|
||||
@@ -120,8 +117,7 @@ class KVPoolScheduler:
|
||||
if num_external_tokens > 0:
|
||||
local_block_ids = blocks.get_block_ids()[0]
|
||||
|
||||
self._unfinished_requests[request.request_id] = (request,
|
||||
local_block_ids)
|
||||
self._unfinished_requests[request.request_id] = (request, local_block_ids)
|
||||
self._unfinished_request_ids.add(request.request_id)
|
||||
if request.request_id not in self.load_specs:
|
||||
# No KV tokens from external KV cache, return
|
||||
@@ -133,18 +129,20 @@ class KVPoolScheduler:
|
||||
return
|
||||
|
||||
assert (
|
||||
num_external_tokens > 0 and num_external_tokens
|
||||
== self.load_specs[request.request_id].kvpool_cached_tokens -
|
||||
self.load_specs[request.request_id].vllm_cached_tokens
|
||||
), (f"Mismatch in number of tokens: {num_external_tokens} vs "
|
||||
num_external_tokens > 0
|
||||
and num_external_tokens
|
||||
== self.load_specs[request.request_id].kvpool_cached_tokens
|
||||
- self.load_specs[request.request_id].vllm_cached_tokens
|
||||
), (
|
||||
f"Mismatch in number of tokens: {num_external_tokens} vs "
|
||||
f"{self.load_specs[request.request_id].kvpool_cached_tokens} - "
|
||||
f"{self.load_specs[request.request_id].vllm_cached_tokens}"
|
||||
f" for request {request.request_id}")
|
||||
f" for request {request.request_id}"
|
||||
)
|
||||
|
||||
self.load_specs[request.request_id].can_load = True
|
||||
|
||||
def build_connector_meta(
|
||||
self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
|
||||
def build_connector_meta(self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
|
||||
"""Attach the connector metadata to the request object.
|
||||
|
||||
This function should NOT modify other fields in the scheduler_output
|
||||
@@ -155,14 +153,13 @@ class KVPoolScheduler:
|
||||
scheduler_output (SchedulerOutput): the scheduler output object.
|
||||
"""
|
||||
|
||||
force_skip_save = (self.kv_role == "kv_consumer"
|
||||
and not self.consumer_is_to_put)
|
||||
force_skip_save = self.kv_role == "kv_consumer" and not self.consumer_is_to_put
|
||||
|
||||
for finished_req_id in scheduler_output.finished_req_ids:
|
||||
self._request_trackers.pop(finished_req_id, None)
|
||||
self._unfinished_requests.pop(finished_req_id, None)
|
||||
self._unfinished_request_ids.discard(finished_req_id)
|
||||
|
||||
|
||||
for req_id in scheduler_output.preempted_req_ids:
|
||||
self._preempted_req_ids.update(scheduler_output.preempted_req_ids)
|
||||
self._request_trackers.pop(req_id, None)
|
||||
@@ -173,9 +170,7 @@ class KVPoolScheduler:
|
||||
for request in scheduler_output.scheduled_new_reqs:
|
||||
# Right now, we only load KV for new requests
|
||||
load_spec = self.load_specs.pop(request.req_id, None)
|
||||
num_tokens_to_compute = (
|
||||
request.num_computed_tokens +
|
||||
scheduler_output.num_scheduled_tokens[request.req_id])
|
||||
num_tokens_to_compute = request.num_computed_tokens + scheduler_output.num_scheduled_tokens[request.req_id]
|
||||
request_tuple = self._unfinished_requests.get(request.req_id)
|
||||
request_real = request_tuple[0] # type: ignore[index]
|
||||
if not isinstance(request.block_ids[0], list):
|
||||
@@ -183,25 +178,25 @@ class KVPoolScheduler:
|
||||
else:
|
||||
unfolded_block_ids = request.block_ids[0].copy()
|
||||
request_tracker = RequestTracker(
|
||||
req_id=request.req_id,
|
||||
token_len=num_tokens_to_compute,
|
||||
allocated_block_ids=unfolded_block_ids,
|
||||
num_saved_tokens=0,
|
||||
)
|
||||
req_id=request.req_id,
|
||||
token_len=num_tokens_to_compute,
|
||||
allocated_block_ids=unfolded_block_ids,
|
||||
num_saved_tokens=0,
|
||||
)
|
||||
self._request_trackers[request.req_id] = request_tracker
|
||||
last_chunk_tokens_num = ((len(request.prompt_token_ids) //
|
||||
self._block_size * self._block_size)
|
||||
if self._discard_partial_chunks else len(
|
||||
request.prompt_token_ids))
|
||||
|
||||
last_chunk_tokens_num = (
|
||||
(len(request.prompt_token_ids) // self._block_size * self._block_size)
|
||||
if self._discard_partial_chunks
|
||||
else len(request.prompt_token_ids)
|
||||
)
|
||||
|
||||
req_meta = ReqMeta.from_request_tracker(
|
||||
request_tracker,
|
||||
self._block_size,
|
||||
load_spec=load_spec,
|
||||
skip_save=force_skip_save,
|
||||
block_hashes=request_real.block_hashes,
|
||||
is_last_chunk=request_tracker.token_len
|
||||
>= last_chunk_tokens_num,
|
||||
is_last_chunk=request_tracker.token_len >= last_chunk_tokens_num,
|
||||
discard_partial_chunks=self._discard_partial_chunks,
|
||||
)
|
||||
if req_meta is not None:
@@ -224,8 +219,8 @@ class KVPoolScheduler:
request_tuple = self._unfinished_requests.get(req_id)
request_real = request_tuple[0] # type: ignore[index]
num_tokens_to_compute = (
request_real.num_computed_tokens +
scheduler_output.num_scheduled_tokens[req_id])
request_real.num_computed_tokens + scheduler_output.num_scheduled_tokens[req_id]
)
request_tracker = RequestTracker(
req_id=req_id,
token_len=num_tokens_to_compute,
@@ -233,21 +228,21 @@ class KVPoolScheduler:
num_saved_tokens=0,
)
self._request_trackers[req_id] = request_tracker
last_chunk_tokens_num = ((len(request_real.prompt_token_ids) //
self._block_size * self._block_size)
if self._discard_partial_chunks else len(
request_real.prompt_token_ids))
last_chunk_tokens_num = (
(len(request_real.prompt_token_ids) // self._block_size * self._block_size)
if self._discard_partial_chunks
else len(request_real.prompt_token_ids)
)
req_meta = ReqMeta.from_request_tracker(
request_tracker,
self._block_size,
load_spec=load_spec,
skip_save=force_skip_save,
block_hashes=request_real.block_hashes,
is_last_chunk=request_tracker.token_len
>= last_chunk_tokens_num,
is_last_chunk=request_tracker.token_len >= last_chunk_tokens_num,
discard_partial_chunks=self._discard_partial_chunks,
)


# decode/chunked request
else:
request_tracker = self._request_trackers[req_id]
@@ -256,48 +251,44 @@ class KVPoolScheduler:
if req_tuple:
request = req_tuple[0]
num_current_tokens = request_tracker.token_len
new_token_ids = request.all_token_ids[
num_current_tokens:num_current_tokens + num_new_tokens]
new_token_ids = request.all_token_ids[num_current_tokens : num_current_tokens + num_new_tokens]
request_tracker.token_len += len(new_token_ids)
else:
raise ValueError(
f"Request {req_id} is not in _unfinished_requests, "
f"but it is scheduled to be cached")
f"Request {req_id} is not in _unfinished_requests, but it is scheduled to be cached"
)
num_computed_token = cached_reqs.num_computed_tokens[i]
if num_computed_token >= len(request.prompt_token_ids):
continue
request_tracker.update(new_block_ids)

last_chunk_tokens_num = ((len(request.prompt_token_ids) //
self._block_size * self._block_size)
if self._discard_partial_chunks else
len(request.prompt_token_ids))
last_chunk_tokens_num = (
(len(request.prompt_token_ids) // self._block_size * self._block_size)
if self._discard_partial_chunks
else len(request.prompt_token_ids)
)
req_meta = ReqMeta.from_request_tracker(
request_tracker,
self._block_size,
load_spec=None,
skip_save=force_skip_save,
block_hashes=request.block_hashes,
is_last_chunk=request_tracker.token_len
>= last_chunk_tokens_num,
is_last_chunk=request_tracker.token_len >= last_chunk_tokens_num,
discard_partial_chunks=self._discard_partial_chunks,
)
if req_meta is not None:
meta.add_request(req_meta)

request_ids = [
req.req_id for req in scheduler_output.scheduled_new_reqs
]
for request_id, (request,
block_ids) in self._unfinished_requests.items():
request_ids = [req.req_id for req in scheduler_output.scheduled_new_reqs]
for request_id, (request, block_ids) in self._unfinished_requests.items():
if request_id not in request_ids and request_id not in cached_reqs.req_ids:
load_spec = self.load_specs.pop(request_id, None)
if not load_spec:
continue
num_tokens_to_compute = load_spec.kvpool_cached_tokens
if (num_tokens_to_compute % self._block_size
!= 0) and (num_tokens_to_compute
== len(request.prompt_token_ids) - 1):
if (num_tokens_to_compute % self._block_size != 0) and (
num_tokens_to_compute == len(request.prompt_token_ids) - 1
):
num_tokens_to_compute = num_tokens_to_compute + 1
request_tracker = RequestTracker(
req_id=request_id,
@@ -324,7 +315,7 @@ class KVPoolScheduler:
self,
request: "Request",
block_ids: list[int],
) -> tuple[bool, Optional[dict[str, Any]]]:
) -> tuple[bool, dict[str, Any] | None]:
"""
Once a request is finished, determine whether request blocks
should be freed now or will be sent asynchronously and freed later.
@@ -336,13 +327,11 @@ class KVPoolScheduler:
return False, None
delay_free_blocks = len(block_ids) > 0
if delay_free_blocks:
logger.info("Delaying free of %d blocks for request %s",
len(block_ids), request.request_id)
logger.info("Delaying free of %d blocks for request %s", len(block_ids), request.request_id)
return delay_free_blocks, None


class LookupKeyClient:

def __init__(self, vllm_config: "VllmConfig"):
self.encoder = MsgpackEncoder()
self.ctx = zmq.Context() # type: ignore[attr-defined]

@@ -1,37 +1,45 @@
import math
import threading
from typing import Dict, Generator, Optional, Type
from collections.abc import Callable, Generator

import torch
from vllm.config import VllmConfig
from vllm.distributed import (get_decode_context_model_parallel_rank,
get_decode_context_model_parallel_world_size,
get_pcp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.distributed import (
get_decode_context_model_parallel_rank,
get_decode_context_model_parallel_world_size,
get_pcp_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from vllm.logger import logger
from vllm.v1.core.kv_cache_utils import BlockHash

from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.backend import \
Backend
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.memcache_backend import \
MemcacheBackend
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.mooncake_backend import \
MooncakeBackend
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.backend import Backend
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.memcache_backend import MemcacheBackend
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.mooncake_backend import MooncakeBackend
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.config_data import (
AscendConnectorMetadata, ChunkedTokenDatabase, KeyMetadata,
LasyerMultiBlockReqMeta, ReqMeta)
AscendConnectorMetadata,
ChunkedTokenDatabase,
KeyMetadata,
LasyerMultiBlockReqMeta,
ReqMeta,
)
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.kv_transfer import (
KVCacheStoreLayerRecvingThread, KVCacheStoreLayerSendingThread,
KVCacheStoreRecvingThread, KVCacheStoreSendingThread, KVTransferThread)
KVCacheStoreLayerRecvingThread,
KVCacheStoreLayerSendingThread,
KVCacheStoreRecvingThread,
KVCacheStoreSendingThread,
KVTransferThread,
)

backend_map: Dict[str, Type[Backend]] = {
backend_map: dict[str, Callable[..., Backend]] = {
"mooncake": MooncakeBackend,
"memcache": MemcacheBackend,
}
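
Reviewer note: backend_map is a plain name-to-constructor registry. A minimal sketch of the intended lookup (the helper name and the ValueError fallback are illustrative, not part of this diff):

def make_backend(name: str, parallel_config) -> Backend:
    # Case-insensitive lookup, mirroring how KVPoolWorker resolves self.backend.
    backend_cls = backend_map.get(name.lower())
    if backend_cls is None:
        raise ValueError(f"unknown kv-pool backend: {name!r}")
    return backend_cls(parallel_config)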


class KVPoolWorker:
#The main class for the cache engine.
# The main class for the cache engine.

def __init__(
self,
@@ -42,9 +50,7 @@ class KVPoolWorker:
parallel_config = vllm_config.parallel_config
self.dp_rank = parallel_config.data_parallel_rank
self.use_mla = False
if (hasattr(model_config, "use_mla")
and isinstance(model_config.use_mla, bool)
and model_config.use_mla):
if hasattr(model_config, "use_mla") and isinstance(model_config.use_mla, bool) and model_config.use_mla:
self.use_mla = True
self.use_layerwise = use_layerwize
self.tp_rank = get_tensor_model_parallel_rank()
@@ -53,19 +59,16 @@ class KVPoolWorker:
self.pp_rank = (parallel_config.rank // self.tp_size) % self.pp_size

self.pcp_size = get_pcp_group().world_size
self.pcp_rank = get_pcp_group(
).rank_in_group if self.pcp_size > 1 else 0
self.pcp_rank = get_pcp_group().rank_in_group if self.pcp_size > 1 else 0
self.dcp_size = get_decode_context_model_parallel_world_size()
self.dcp_rank = get_decode_context_model_parallel_rank(
) if self.dcp_size > 1 else 0
self.dcp_rank = get_decode_context_model_parallel_rank() if self.dcp_size > 1 else 0

self.kv_role = vllm_config.kv_transfer_config.kv_role
self.load_async = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"load_async", False)
self.load_async = vllm_config.kv_transfer_config.kv_connector_extra_config.get("load_async", False)
self.consumer_is_to_put = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"consumer_is_to_put", False)
self.backend = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"backend", "mooncake")
"consumer_is_to_put", False
)
self.backend = vllm_config.kv_transfer_config.kv_connector_extra_config.get("backend", "mooncake")
self.block_size = vllm_config.cache_config.block_size

if self.pcp_size > 1:
@@ -88,7 +91,7 @@ class KVPoolWorker:
self.put_step = 1

self.metadata = KeyMetadata(
model_config.model.rstrip('/').split('/')[-1],
model_config.model.rstrip("/").split("/")[-1],
self.head_or_tp_rank,
self.pcp_rank,
self.dcp_rank,
@@ -99,40 +102,28 @@ class KVPoolWorker:
if self.kv_role == "kv_consumer" and self.consumer_is_to_put:
num_hidden_layers = model_config.hf_text_config.num_hidden_layers
partition_list_str = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"prefill_pp_layer_partition", None)
prefill_pp_size = int(
vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"prefill_pp_size", 1))
"prefill_pp_layer_partition", None
)
prefill_pp_size = int(vllm_config.kv_transfer_config.kv_connector_extra_config.get("prefill_pp_size", 1))

if partition_list_str is not None:
try:
partitions = [
int(layer) for layer in partition_list_str.split(",")
]
partitions = [int(layer) for layer in partition_list_str.split(",")]
except ValueError as err:
raise ValueError("Invalid partition string: {}".format(
partition_list_str)) from err
raise ValueError("Invalid partition string: {}".format(partition_list_str)) from err
if len(partitions) != prefill_pp_size:
raise ValueError(
f"{len(partitions)=} does not match {prefill_pp_size=}."
)
raise ValueError(f"{len(partitions)=} does not match {prefill_pp_size=}.")
if sum(partitions) != num_hidden_layers:
raise ValueError(
f"{sum(partitions)=} does not match {num_hidden_layers=}."
)
raise ValueError(f"{sum(partitions)=} does not match {num_hidden_layers=}.")
else:
layers_per_partition = num_hidden_layers // prefill_pp_size
partitions = [
layers_per_partition for _ in range(prefill_pp_size)
]
partitions = [layers_per_partition for _ in range(prefill_pp_size)]

if remaining_layers := num_hidden_layers % prefill_pp_size:
for i in range(2, remaining_layers + 2):
partitions[-i] += 1
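
Reviewer note: the remainder loop above hands the leftover layers to the stages just before the last one (partitions[-2], partitions[-3], ...), leaving the final stage at the base size. A self-contained sketch of the same computation:

def split_layers(num_hidden_layers: int, prefill_pp_size: int) -> list[int]:
    # Even split first, then distribute the remainder backwards from the
    # second-to-last stage, exactly as in KVPoolWorker.__init__ above.
    base = num_hidden_layers // prefill_pp_size
    partitions = [base] * prefill_pp_size
    if remaining := num_hidden_layers % prefill_pp_size:
        for i in range(2, remaining + 2):
            partitions[-i] += 1
    return partitions

# e.g. split_layers(30, 4) -> [7, 8, 8, 7]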

self.token_database = ChunkedTokenDatabase(self.metadata,
self.block_size,
self.use_mla, partitions)
self.token_database = ChunkedTokenDatabase(self.metadata, self.block_size, self.use_mla, partitions)

real_backend = backend_map.get(self.backend.lower())

@@ -142,10 +133,11 @@ class KVPoolWorker:
self.put_step = 1

self.m_store = real_backend( # type: ignore[misc]
parallel_config)
parallel_config
)

self.kv_send_thread: Optional[KVTransferThread] = None
self.kv_recv_thread: Optional[KVTransferThread] = None
self.kv_send_thread: KVTransferThread | None = None
self.kv_recv_thread: KVTransferThread | None = None

self.finished_store_req: set[str] = set()

@@ -162,11 +154,14 @@ class KVPoolWorker:
block_shape_pe = first_kv_cache_tuple[1].shape[-block_rank:]
self.block_len = [
first_kv_cache[0].element_size() * math.prod(block_shape_norm),
first_kv_cache[1].element_size() * math.prod(block_shape_pe)
first_kv_cache[1].element_size() * math.prod(block_shape_pe),
]
logger.info(
"num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s",
self.num_blocks, block_shape_norm, block_shape_pe)
self.num_blocks,
block_shape_norm,
block_shape_pe,
)
else:
# [num_block, block_size, num_head, hidden_dim]
self.num_blocks = first_kv_cache.shape[0]
@@ -174,11 +169,9 @@ class KVPoolWorker:
block_rank = 3 # [block_size, kv_heads, head_dim]
block_shape = first_kv_cache.shape[-block_rank:]
self.block_len = [kv_elem_size * math.prod(block_shape)]
logger.info("num_blocks: %s, block_shape: %s", self.num_blocks,
block_shape)
logger.info("num_blocks: %s, block_shape: %s", self.num_blocks, block_shape)

logger.info("Registering KV_Caches. use_mla: %s, shape %s",
self.use_mla, first_kv_cache.shape)
logger.info("Registering KV_Caches. use_mla: %s, shape %s", self.use_mla, first_kv_cache.shape)

self.kv_caches = kv_caches
self.kv_caches_base_addr = []
@@ -194,8 +187,7 @@ class KVPoolWorker:
ptrs.append(base_addr)
lengths.append(region_len)
else:
cache_list = [cache_or_caches
] if self.use_mla else cache_or_caches
cache_list = [cache_or_caches] if self.use_mla else cache_or_caches
for cache in cache_list:
base_addr = cache.data_ptr()
self.kv_caches_base_addr.append(base_addr)
@@ -208,33 +200,50 @@ class KVPoolWorker:

if self.use_layerwise:
self.get_event = threading.Event()
if self.kv_role in ['kv_producer', 'kv_both']:
if self.kv_role in ["kv_producer", "kv_both"]:
ready_event_sending = threading.Event()
self.kv_send_thread = KVCacheStoreLayerSendingThread(
self.m_store, self.token_database, self.block_size,
self.tp_rank, self.dcp_size, self.put_step,
ready_event_sending, self.num_layers)
self.m_store,
self.token_database,
self.block_size,
self.tp_rank,
self.dcp_size,
self.put_step,
ready_event_sending,
self.num_layers,
)
self.kv_send_thread.start()
ready_event = threading.Event()
self.kv_recv_thread = KVCacheStoreLayerRecvingThread(
self.m_store, self.token_database, self.block_size,
self.tp_rank, self.dcp_size, ready_event, self.get_event)
self.m_store,
self.token_database,
self.block_size,
self.tp_rank,
self.dcp_size,
ready_event,
self.get_event,
)
self.kv_recv_thread.start()
ready_event.wait()
else:
if self.kv_role in ['kv_producer', 'kv_both'
] or self.consumer_is_to_put:
if self.kv_role in ["kv_producer", "kv_both"] or self.consumer_is_to_put:
ready_event_sending = threading.Event()
self.kv_send_thread = KVCacheStoreSendingThread(
self.m_store, self.token_database, self.block_size,
self.tp_rank, self.dcp_size, self.put_step, self.kv_role,
ready_event_sending)
self.m_store,
self.token_database,
self.block_size,
self.tp_rank,
self.dcp_size,
self.put_step,
self.kv_role,
ready_event_sending,
)
self.kv_send_thread.start()
if self.load_async:
ready_event = threading.Event()
self.kv_recv_thread = KVCacheStoreRecvingThread(
self.m_store, self.token_database, self.block_size,
self.tp_rank, self.dcp_size, ready_event)
self.m_store, self.token_database, self.block_size, self.tp_rank, self.dcp_size, ready_event
)
self.kv_recv_thread.start()
ready_event.wait()

@@ -243,12 +252,12 @@ class KVPoolWorker:
self.layerwise_retrievers = []
for request in metadata.requests:
load_spec = request.load_spec
if load_spec is None or not load_spec.can_load: #load =0
if load_spec is None or not load_spec.can_load: # load =0
continue
token_len = request.token_len_chunk
if (load_spec.kvpool_cached_tokens % self.block_size
!= 0) and (load_spec.kvpool_cached_tokens
== token_len - 1):
if (load_spec.kvpool_cached_tokens % self.block_size != 0) and (
load_spec.kvpool_cached_tokens == token_len - 1
):
token_len = request.load_spec.kvpool_cached_tokens + 1
else:
token_len = request.load_spec.kvpool_cached_tokens
@@ -260,30 +269,27 @@ class KVPoolWorker:
else:
if self.load_async:
self.kv_recv_thread.add_request( # type: ignore[union-attr]
request, )
request,
)
else:
addr_list = []
size_list = []
key_list = []
mask_num = (request.load_spec.vllm_cached_tokens //
self.block_size * self.block_size)
mask_num = request.load_spec.vllm_cached_tokens // self.block_size * self.block_size
for start, end, key in self.token_database.process_tokens(
token_len, request.block_hashes, mask_num):
addr, size, _ = self.token_database.prepare_value(
start, end, request.block_ids)
token_len, request.block_hashes, mask_num
):
addr, size, _ = self.token_database.prepare_value(start, end, request.block_ids)
key_list.append(key.to_string())
addr_list.append(addr)
size_list.append(size)
key_list_c = key_list[self.tp_rank % len(
key_list):] + key_list[:self.tp_rank % len(key_list)]
addr_list_c = addr_list[self.tp_rank %
len(addr_list
):] + addr_list[:self.tp_rank %
len(addr_list)]
size_list_c = size_list[self.tp_rank %
len(size_list
):] + size_list[:self.tp_rank %
len(size_list)]
key_list_c = key_list[self.tp_rank % len(key_list) :] + key_list[: self.tp_rank % len(key_list)]
addr_list_c = (
addr_list[self.tp_rank % len(addr_list) :] + addr_list[: self.tp_rank % len(addr_list)]
)
size_list_c = (
size_list[self.tp_rank % len(size_list) :] + size_list[: self.tp_rank % len(size_list)]
)
self.m_store.get(key_list_c, addr_list_c, size_list_c)
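
Reviewer note: the key/addr/size slicing above is an ordinary list rotation; each TP rank starts its m_store.get() batch at a different offset so the ranks do not all hit the same key first. The reformatted one-liners are equivalent to this sketch:

def rotate(items: list, rank: int) -> list:
    # items[r:] + items[:r] rotates the list left by rank positions.
    r = rank % len(items)
    return items[r:] + items[:r]

# e.g. rotate(["k0", "k1", "k2"], 1) -> ["k1", "k2", "k0"]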

def wait_for_layer_load(self) -> None:
@@ -294,8 +300,7 @@ class KVPoolWorker:
num_retrieved_tokens = ret_token_mask.sum().item()
logger.debug(f"Retrieved {num_retrieved_tokens} tokens")

def save_kv_layer(self,
connector_metadata: AscendConnectorMetadata) -> None:
def save_kv_layer(self, connector_metadata: AscendConnectorMetadata) -> None:
if self.current_layer == 0:
self.layerwise_storers = []
current_event = None
@@ -336,15 +341,17 @@ class KVPoolWorker:
continue

request.current_event = current_event
self.kv_send_thread.add_stored_request( # type: ignore[union-attr]
request.req_id)
self.kv_send_thread.add_stored_request( # type: ignore[union-attr]
request.req_id
)
self.kv_send_thread.add_request( # type: ignore[union-attr]
request, )
request,
)

def retrieve_layer(
self,
request: ReqMeta,
) -> Generator[Optional[torch.Tensor], None, None]:
) -> Generator[torch.Tensor | None, None, None]:
"""
Retrieve the KV cache in a layerwise manner.

@@ -359,12 +366,14 @@ class KVPoolWorker:

return: A generator that yields Optional[torch.Tensor]. The tensor will
be the boolean mask indicating which tokens are retrieved and will
only be returned in the last iteration.
only be returned in the last iteration.
"""
token_len = request.token_len_chunk
mask_num = (
request.load_spec.vllm_cached_tokens # type: ignore[union-attr]
// self.block_size * self.block_size)
// self.block_size
* self.block_size
)
num_required_tokens = token_len - mask_num

ret_mask = torch.zeros(token_len, dtype=torch.bool, device="cpu")
@@ -373,8 +382,7 @@ class KVPoolWorker:
ends = []
keys = []
first_flag = True
for start, end, key in self.token_database.process_tokens(
token_len, request.block_hashes, mask_num):
for start, end, key in self.token_database.process_tokens(token_len, request.block_hashes, mask_num):
keys_multi_layer = key.split_layers(self.num_layers)
starts.append(start)
ends.append(end)
@@ -386,16 +394,16 @@ class KVPoolWorker:
keys = [list(row) for row in zip(*keys)] # [num_layer,block_num]
for layer_id, keys_multi_chunk in enumerate(keys):
if not first_flag:
is_finish = self.get_event.wait(timeout=3) #try---cache
is_finish = self.get_event.wait(timeout=3) # try---cache
if not is_finish:
logger.info("Layerwise get failed")
self.get_event.clear()
req_meta = LasyerMultiBlockReqMeta(request.req_id,
keys_multi_chunk, starts,
ends, request.block_ids,
layer_id)
req_meta = LasyerMultiBlockReqMeta(
request.req_id, keys_multi_chunk, starts, ends, request.block_ids, layer_id
)
self.kv_recv_thread.add_request( # type: ignore[union-attr, call-arg]
req_meta) # type: ignore[union-attr, call-arg, arg-type]
req_meta
) # type: ignore[union-attr, call-arg, arg-type]
first_flag = False
yield None
else:
@@ -405,16 +413,14 @@ class KVPoolWorker:
yield None

retrieved_tokens = torch.sum(ret_mask)
logger.debug(f"Retrieved {retrieved_tokens} "
f"out of {num_required_tokens} "
f"out of total {token_len} tokens")
logger.debug(f"Retrieved {retrieved_tokens} out of {num_required_tokens} out of total {token_len} tokens")

yield ret_mask

def store_layer(
self,
request: ReqMeta,
current_event: Optional[torch.npu.Event],
current_event: torch.npu.Event | None,
) -> Generator[None, None, None]:
"""
Store the KV cache in a layerwise manner.
@@ -439,69 +445,88 @@ class KVPoolWorker:
starts = []
ends = []
keys = []
for start, end, key in self.token_database.process_tokens(
request.token_len_chunk, request.block_hashes):
for start, end, key in self.token_database.process_tokens(request.token_len_chunk, request.block_hashes):
keys_multi_layer = key.split_layers(self.num_layers)
starts.append(start)
ends.append(end)
keys.append(keys_multi_layer) #[block_num,layer_num]
keys.append(keys_multi_layer) # [block_num,layer_num]

if keys:
keys = [list(row) for row in zip(*keys)] #[layer_num,block_num]
keys = [list(row) for row in zip(*keys)] # [layer_num,block_num]
for layer_id, keys_multi_chunk in enumerate(keys):
req_meta = LasyerMultiBlockReqMeta(request.req_id,
keys_multi_chunk, starts,
ends, request.block_ids,
layer_id,
request.is_last_chunk,
current_event)
req_meta = LasyerMultiBlockReqMeta(
request.req_id,
keys_multi_chunk,
starts,
ends,
request.block_ids,
layer_id,
request.is_last_chunk,
current_event,
)
self.kv_send_thread.add_request( # type: ignore[union-attr, call-arg]
req_meta) # type: ignore[union-attr, call-arg, arg-type]
req_meta
) # type: ignore[union-attr, call-arg, arg-type]
yield
else:
for layer_id in range(self.num_layers):
yield

def get_finished(self,
finished_req_ids: set[str], meta:AscendConnectorMetadata) -> tuple[set[str], set[str]]:
def get_finished(self, finished_req_ids: set[str], meta: AscendConnectorMetadata) -> tuple[set[str], set[str]]:
done_sending = (
self.get_and_clear_finished_requests(
finished_req_ids, meta # type: ignore[union-attr]
) if self.kv_role in ['kv_producer', 'kv_both']
or self.consumer_is_to_put else set())
finished_req_ids,
meta, # type: ignore[union-attr]
)
if self.kv_role in ["kv_producer", "kv_both"] or self.consumer_is_to_put
else set()
)

done_recving = (
self.kv_recv_thread.
get_and_clear_finished_requests( # type: ignore[union-attr]
) if self.load_async else set())
self.kv_recv_thread.get_and_clear_finished_requests( # type: ignore[union-attr]
)
if self.load_async
else set()
)

logger.debug(
"Number of completed KV cache send requests: %d, receive "
"requests: %d, tp_rank:%d", len(done_sending), len(done_recving),
self.tp_rank)
"Number of completed KV cache send requests: %d, receive requests: %d, tp_rank:%d",
len(done_sending),
len(done_recving),
self.tp_rank,
)
return done_sending, done_recving

def get_and_clear_finished_requests(self, finished_req_ids, meta:AscendConnectorMetadata) -> set[str]:
def get_and_clear_finished_requests(self, finished_req_ids, meta: AscendConnectorMetadata) -> set[str]:
finished_sending = set()
for req_id in meta.preempted_req_ids:
self.kv_send_thread.delete_finished_stored_request( # type: ignore[union-attr]
req_id)
req_id
)
for req_id in self.kv_send_thread.stored_requests.copy( # type: ignore[union-attr]
):
if self.kv_send_thread.stored_requests[ # type: ignore[union-attr]
req_id] == 0 and req_id in self.finished_store_req:
if (
self.kv_send_thread.stored_requests[ # type: ignore[union-attr]
req_id
]
== 0
and req_id in self.finished_store_req
):
self.finished_store_req.remove(req_id)
finished_sending.add(req_id)
self.kv_send_thread.delete_finished_stored_request( # type: ignore[union-attr]
req_id)
req_id
)

for req_id in finished_req_ids:
req_remain_jobs = self.kv_send_thread.stored_requests.get( # type: ignore[union-attr]
req_id)
req_id
)
if req_remain_jobs == 0:
finished_sending.add(req_id)
self.kv_send_thread.delete_finished_stored_request( # type: ignore[union-attr]
req_id)
req_id
)
elif req_remain_jobs is not None:
self.finished_store_req.add(req_id)

@@ -522,8 +547,7 @@ class KVPoolWorker:
keys = []
try:
starts = []
for start, end, key in self.token_database.process_tokens(
token_len, block_hashes):
for start, end, key in self.token_database.process_tokens(token_len, block_hashes):
if use_layerwise:
keys_multi_layer = key.split_layers(self.num_layers)
for item in keys_multi_layer:
@@ -560,8 +584,7 @@ class KVPoolWorker:
keys = []
try:
starts = []
for start, end, key in self.token_database.process_tokens(
token_len, block_hashes):
for start, end, key in self.token_database.process_tokens(token_len, block_hashes):
if use_layerwise:
keys_multi_layer = key.split_layers(self.num_layers)
for item in keys_multi_layer:
@@ -574,25 +597,25 @@ class KVPoolWorker:
for i in range(1, min(self.tp_size, self.num_kv_head)):
for item in keys:
new_str = item.replace( # type: ignore[attr-defined]
"@head_or_tp_rank:0", f"@head_or_tp_rank:{i}", 1)
"@head_or_tp_rank:0", f"@head_or_tp_rank:{i}", 1
)
multi_tp_keys.append(new_str)

for i in range(1, self.pp_size):
for item in keys:
new_str = item.replace( # type: ignore[attr-defined]
"@pp_rank:0", f"@pp_rank:{i}", 1)
"@pp_rank:0", f"@pp_rank:{i}", 1
)
multi_tp_keys.append(new_str)

res = self.m_store.exists(
multi_tp_keys) # type: ignore[assignment]
res = self.m_store.exists(multi_tp_keys) # type: ignore[assignment]
num_block = len(keys)
if use_layerwise:
res = self.check_all_layers_exists(res, self.num_layers)
num_block = len(keys) // self.num_layers
multi_tp_values = [
res[i * num_block:(i + 1) * num_block] # type: ignore[index]
for i in range(
min(self.tp_size, self.num_kv_head) * self.pp_size)
res[i * num_block : (i + 1) * num_block] # type: ignore[index]
for i in range(min(self.tp_size, self.num_kv_head) * self.pp_size)
]
index = self.find_min_first_non_one_index(multi_tp_values)
if index != -1:
@@ -603,8 +626,7 @@ class KVPoolWorker:
return start
return end

def check_all_layers_exists(self, res: list[int],
num_layers: int) -> list[int]:
def check_all_layers_exists(self, res: list[int], num_layers: int) -> list[int]:
total_chunks = len(res) // num_layers
result = []

@@ -618,7 +640,6 @@ class KVPoolWorker:

def find_min_first_non_one_index(self, arr):
try:
return min(idx for row in arr for idx, val in enumerate(row)
if val != 1)
return min(idx for row in arr for idx, val in enumerate(row) if val != 1)
except ValueError:
return -1
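
Reviewer note: find_min_first_non_one_index returns the earliest position, across all rank rows, whose existence flag is not 1, i.e. the first chunk any rank is missing; -1 means every chunk is present everywhere. A minimal usage sketch with made-up vectors:

# 1 means the chunk exists in the pool; anything else means missing.
multi_tp_values = [
    [1, 1, 1],  # rank 0: all three chunks present
    [1, 1, 0],  # rank 1: chunk 2 missing
]
# The earliest miss across ranks caps the usable prefix at 2 chunks.
assert min(idx for row in multi_tp_values
           for idx, val in enumerate(row) if val != 1) == 2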

@@ -1,20 +1,17 @@
import time
from collections import defaultdict
from typing import Optional

from vllm.logger import logger
from vllm.utils.hashing import sha256
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock)
from vllm.v1.core.single_type_kv_cache_manager import \
get_manager_for_kv_cache_spec
from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock
from vllm.v1.core.single_type_kv_cache_manager import get_manager_for_kv_cache_spec
from vllm.v1.kv_cache_interface import KVCacheSpec
from vllm.v1.metrics.stats import (PrefixCacheStats, CachingMetrics)
from vllm.v1.metrics.stats import CachingMetrics, PrefixCacheStats
from vllm.v1.request import Request


class CPUCacheStats:

def __init__(self, enable_prefix_caching: bool, log_stats: bool = False):
self.enable_prefix_caching = enable_prefix_caching
self.log_stats = log_stats
@@ -27,10 +24,9 @@ class CPUCacheStats:
# Log the prefix cache hit rate every 10 seconds.
if current_time_sec - self.time_sec >= 10:
self.time_sec = current_time_sec
logger.info("CPU Prefix cache hit rate: %.1f%%",
self.cpu_prefix_cache_metrics.hit_rate * 100)
logger.info("CPU Prefix cache hit rate: %.1f%%", self.cpu_prefix_cache_metrics.hit_rate * 100)

def make_prefix_cache_stats(self) -> Optional[PrefixCacheStats]:
def make_prefix_cache_stats(self) -> PrefixCacheStats | None:
"""Get (and reset) the prefix cache stats.
Returns:
The current prefix caching stats, or None if logging is disabled.
@@ -57,7 +53,6 @@ class CPUCacheStats:


class CPUKVCacheManager:

def __init__(
self,
kv_cache_spec: KVCacheSpec,
@@ -70,30 +65,26 @@ class CPUKVCacheManager:
self.num_cpu_blocks = num_cpu_blocks
self.caching_hash_fn = sha256 if caching_hash_algo == "sha256" else hash
self.use_eagle = use_eagle
self.block_pool = BlockPool(self.num_cpu_blocks, True,
enable_kv_cache_events)
self.block_pool = BlockPool(self.num_cpu_blocks, True, enable_kv_cache_events)
self.single_type_manager = get_manager_for_kv_cache_spec(
kv_cache_spec=kv_cache_spec,
block_pool=self.block_pool,
kv_cache_group_id=0,
)
# Record kv block hashes, avoid redundant computation.
self.req_to_block_hashes: defaultdict[
str, list[BlockHash]] = defaultdict(list)
self.req_to_block_hashes: defaultdict[str, list[BlockHash]] = defaultdict(list)
# Record blocks touched in get_matched_num_and_touch().
self.req_to_computed_blocks: defaultdict[
str, list[KVCacheBlock]] = defaultdict(list)
self.req_to_computed_blocks: defaultdict[str, list[KVCacheBlock]] = defaultdict(list)
# Record the request that failed to allocate.
self.req_failed_to_allocate: defaultdict[str, bool] = defaultdict(bool)
self.req_to_num_tokens: defaultdict[str, int] = defaultdict(int)
self.cpu_cache_stats = CPUCacheStats(enable_prefix_caching=True,
log_stats=True)
self.cpu_cache_stats = CPUCacheStats(enable_prefix_caching=True, log_stats=True)
# Record request that will be free after finish sending
self.req_to_free: defaultdict[str, Request] = defaultdict(Request)

def get_matched_num_and_touch(self, request: Request) -> tuple[int, bool]:
# When the request requires prompt logprobs, we skip prefix caching.
if (request.sampling_params.prompt_logprobs is not None):
if request.sampling_params.prompt_logprobs is not None:
return 0, False
request_id = request.request_id
# The block hashes for the request may already be computed
@@ -119,10 +110,8 @@ class CPUKVCacheManager:

# cpu prefix cache status set and log
assert self.cpu_cache_stats is not None and self.cpu_cache_stats.prefix_cache_stats is not None
self.cpu_cache_stats.set_cache_stats(request.num_tokens,
num_computed_tokens)
self.cpu_cache_stats.cpu_prefix_cache_metrics.observe(
self.cpu_cache_stats.prefix_cache_stats)
self.cpu_cache_stats.set_cache_stats(request.num_tokens, num_computed_tokens)
self.cpu_cache_stats.cpu_prefix_cache_metrics.observe(self.cpu_cache_stats.prefix_cache_stats)
self.cpu_cache_stats.log()

return num_computed_tokens, False
@@ -130,12 +119,10 @@ class CPUKVCacheManager:
def _release_ahead_touch(self, request_id: str):
computed_blocks = self.req_to_computed_blocks[request_id]
if computed_blocks:
self.single_type_manager.block_pool.free_blocks(
reversed(computed_blocks))
self.single_type_manager.block_pool.free_blocks(reversed(computed_blocks))
self.req_to_computed_blocks.pop(request_id, None)

def allocate_slots(self, req_to_num_tokens: dict[str, int],
unallocated_req_ids: set[str]) -> dict[str, list[int]]:
def allocate_slots(self, req_to_num_tokens: dict[str, int], unallocated_req_ids: set[str]) -> dict[str, list[int]]:
for request_id in unallocated_req_ids:
self._free_slots(request_id)
req_to_new_blocks = {}
@@ -143,44 +130,34 @@ class CPUKVCacheManager:
if self.req_failed_to_allocate[request_id]:
continue
new_computed_blocks = self.req_to_computed_blocks[request_id]
num_blocks_to_allocate = (
self.single_type_manager.get_num_blocks_to_allocate(
request_id=request_id,
num_tokens=num_tokens,
new_computed_blocks=new_computed_blocks,
))
num_blocks_to_allocate = self.single_type_manager.get_num_blocks_to_allocate(
request_id=request_id,
num_tokens=num_tokens,
new_computed_blocks=new_computed_blocks,
)
if num_blocks_to_allocate > self.block_pool.get_num_free_blocks():
self._release_ahead_touch(request_id)
self.req_failed_to_allocate[request_id] = True
continue
# Append the new computed blocks to the request blocks until now to
# avoid the case where the new blocks cannot be allocated.
self.single_type_manager.save_new_computed_blocks(
request_id, new_computed_blocks)
self.single_type_manager.save_new_computed_blocks(request_id, new_computed_blocks)
# Allocate new blocks but do not cache now.
new_blocks = self.single_type_manager.allocate_new_blocks(
request_id, num_tokens)
new_blocks = self.single_type_manager.allocate_new_blocks(request_id, num_tokens)
self.req_to_num_tokens[request_id] = num_tokens
# No need to release ref_cnt because we use officially.
self.req_to_computed_blocks.pop(request_id, None)
req_to_new_blocks[request_id] = [
block.block_id for block in new_computed_blocks + new_blocks
]
req_to_new_blocks[request_id] = [block.block_id for block in new_computed_blocks + new_blocks]
return req_to_new_blocks
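
Reviewer note: allocate_slots is a touch-then-commit scheme: the cached prefix was already touched (ref-counted) in get_matched_num_and_touch, so a failed allocation must roll that touch back before reporting failure. A stripped-down sketch of the control flow (names are illustrative):

def try_allocate(need: int, free: int, commit_prefix, rollback_prefix) -> bool:
    # Mirrors the branch above: not enough free blocks -> release the
    # touched prefix and mark the request as failed; otherwise commit the
    # prefix and proceed to allocate the new blocks.
    if need > free:
        rollback_prefix()
        return False
    commit_prefix()
    return True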

def record_request_cache_and_free_slots(self, request: Request):
logger.debug(
f"record_request_cache_and_free_slots for request {request.request_id} in cpu_kv_cache_manager"
)
logger.debug(f"record_request_cache_and_free_slots for request {request.request_id} in cpu_kv_cache_manager")
self.req_to_free[request.request_id] = request

def cache_and_free_slots(self, request_id: str):
logger.debug(
f"Cache and free slots for request {request_id} in cpu_kv_cache_manager"
)
logger.debug(f"Cache and free slots for request {request_id} in cpu_kv_cache_manager")
if request_id not in self.req_to_free:
logger.Error(
f"request {request_id} not in req_to_free, maybe bug!")
logger.error(f"request {request_id} not in req_to_free, maybe bug!")
return
request = self.req_to_free[request_id]
if not self.req_failed_to_allocate[request_id]:
@@ -189,8 +166,7 @@ class CPUKVCacheManager:
self.req_to_num_tokens[request_id],
)
self._free_slots(request_id)
logger.debug(
f"delete request {request_id} in cpu_kv_cache_manager req_to_free")
logger.debug(f"delete request {request_id} in cpu_kv_cache_manager req_to_free")
del self.req_to_free[request_id]

def _free_slots(self, request_id: str):

@@ -5,15 +5,15 @@ import queue
import threading
import time
from collections import defaultdict
from collections.abc import Sequence
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional, Sequence
from typing import TYPE_CHECKING, Any, Optional

import torch
from vllm.attention.layer import Attention, MLAAttention
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
from vllm.logger import logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
@@ -23,12 +23,14 @@ from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheSpec

from vllm_ascend.distributed.kv_transfer.kv_pool.cpu_offload.metadata import (
MetadataServer, MetadataServerProc, MLAConfig)

MetadataServer,
MetadataServerProc,
MLAConfig,
)

if TYPE_CHECKING:
from vllm.v1.attention.backend import AttentionMetadata #type: ignore
from vllm.forward_context import ForwardContext
from vllm.v1.attention.backend import AttentionMetadata # type: ignore
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.request import Request
@@ -59,20 +61,15 @@ class CPUOffloadingConnectorMetadata(KVConnectorMetadata):


class CPUOffloadingConnector(KVConnectorBase_V1):

def __init__(self,
vllm_config: VllmConfig,
role: KVConnectorRole,
kv_cache_config: Optional["KVCacheConfig"] = None):
def __init__(
self, vllm_config: VllmConfig, role: KVConnectorRole, kv_cache_config: Optional["KVCacheConfig"] = None
):
self._connector_metadata = CPUOffloadingConnectorMetadata(requests={}, finished_req_ids=set())
if not vllm_config.cache_config.enable_prefix_caching:
self.connector_scheduler: Optional[
CPUOffloadingConnectorScheduler] = None
self.connector_worker: Optional[
CPUOffloadingConnectorWorker] = None
self.connector_scheduler: CPUOffloadingConnectorScheduler | None = None
self.connector_worker: CPUOffloadingConnectorWorker | None = None
elif role == KVConnectorRole.SCHEDULER:
self.connector_scheduler = CPUOffloadingConnectorScheduler(
vllm_config)
self.connector_scheduler = CPUOffloadingConnectorScheduler(vllm_config)
self.connector_worker = None
elif role == KVConnectorRole.WORKER:
self.connector_scheduler = None
@@ -82,11 +79,9 @@ class CPUOffloadingConnector(KVConnectorBase_V1):
# Worker-side methods
# ==============================

def bind_connector_metadata(
self, connector_metadata: KVConnectorMetadata) -> None:
def bind_connector_metadata(self, connector_metadata: KVConnectorMetadata) -> None:
if self.connector_worker is not None:
assert isinstance(connector_metadata,
CPUOffloadingConnectorMetadata)
assert isinstance(connector_metadata, CPUOffloadingConnectorMetadata)
self.connector_worker.bind_connector_metadata(connector_metadata)

def clear_connector_metadata(self) -> None:
@@ -97,8 +92,7 @@ class CPUOffloadingConnector(KVConnectorBase_V1):
if self.connector_worker is not None:
self.connector_worker.register_kv_caches(kv_caches)

def start_load_kv(self, forward_context: "ForwardContext",
**kwargs) -> None:
def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
if self.connector_worker is not None:
self.connector_worker.start_load_kv()

@@ -106,53 +100,42 @@ class CPUOffloadingConnector(KVConnectorBase_V1):
if self.connector_worker is not None:
self.connector_worker.wait_for_layer_load()

def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
attn_metadata: "AttentionMetadata", **kwargs) -> None:
def save_kv_layer(
self, layer_name: str, kv_layer: torch.Tensor, attn_metadata: "AttentionMetadata", **kwargs
) -> None:
pass

def wait_for_save(self):
pass

def get_finished(
self, finished_req_ids: set[str]
) -> tuple[Optional[set[str]], Optional[set[str]]]:
def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str] | None, set[str] | None]:
assert self.connector_worker is not None
return self.connector_worker.get_finished(), None

# Scheduler-side methods
# ==============================

def get_num_new_matched_tokens(
self, request: "Request",
num_computed_tokens: int) -> tuple[int, bool]:
def get_num_new_matched_tokens(self, request: "Request", num_computed_tokens: int) -> tuple[int, bool]:
if self.connector_scheduler is not None:
return self.connector_scheduler.get_num_new_matched_tokens(
request, num_computed_tokens)
return self.connector_scheduler.get_num_new_matched_tokens(request, num_computed_tokens)
return 0, False

def update_state_after_alloc(self, request: "Request",
blocks: "KVCacheBlocks",
num_external_tokens: int):
def update_state_after_alloc(self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int):
if self.connector_scheduler is not None:
return self.connector_scheduler.update_state_after_alloc(request)

def build_connector_meta(
self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
def build_connector_meta(self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
if self.connector_scheduler is not None:
return self.connector_scheduler.build_connector_meta(
scheduler_output)
return self.connector_scheduler.build_connector_meta(scheduler_output)
return KVConnectorMetadata()

def request_finished(
self, request: "Request",
block_ids: list[int]) -> tuple[bool, Optional[dict[str, Any]]]:
def request_finished(self, request: "Request", block_ids: list[int]) -> tuple[bool, dict[str, Any] | None]:
if self.connector_scheduler is not None:
self.connector_scheduler.request_finished(request)
return True, None


class CPUOffloadingConnectorScheduler:

def __init__(self, vllm_config: VllmConfig):
logger.info("init CPUOffloadingConnectorScheduler")
self.vllm_config = vllm_config
@@ -165,22 +148,17 @@ class CPUOffloadingConnectorScheduler:
self.zmq_rpc_client = MetadataServer.ZMQRPCClient()
self.zmq_rpc_client.call("post_init")
if vllm_config.kv_transfer_config is not None:
self.swap_in_threshold = vllm_config.kv_transfer_config.get_from_extra_config(
"swap_in_threshold", 0)
self.swap_in_threshold = vllm_config.kv_transfer_config.get_from_extra_config("swap_in_threshold", 0)
else:
self.swap_in_threshold = 0
logger.info(f"swap_in_threshold: {self.swap_in_threshold}")

def get_num_new_matched_tokens(
self, ori_request: "Request",
num_computed_tokens: int) -> tuple[int, bool]:
def get_num_new_matched_tokens(self, ori_request: "Request", num_computed_tokens: int) -> tuple[int, bool]:
request = copy.deepcopy(ori_request)
request.get_hash_new_full_blocks = None
num_cpu_computed_tokens, load_async = self.zmq_rpc_client.call(
"get_matched_num_and_touch", request)
num_cpu_computed_tokens, load_async = self.zmq_rpc_client.call("get_matched_num_and_touch", request)
self.num_gpu_computed_tokens[request.request_id] = num_computed_tokens
self.num_cpu_computed_tokens[
request.request_id] = num_cpu_computed_tokens
self.num_cpu_computed_tokens[request.request_id] = num_cpu_computed_tokens
if num_cpu_computed_tokens - num_computed_tokens >= self.swap_in_threshold:
return num_cpu_computed_tokens - num_computed_tokens, load_async
else:
@@ -189,29 +167,22 @@ class CPUOffloadingConnectorScheduler:
def update_state_after_alloc(self, request: "Request"):
self.allocated_req_ids.add(request.request_id)

def build_connector_meta(
self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
def build_connector_meta(self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
num_tokens = {}
# process scheduled_new_reqs
for req in scheduler_output.scheduled_new_reqs:
req_id = req.req_id
num_tokens[req_id] = (
req.num_computed_tokens +
scheduler_output.num_scheduled_tokens[req_id])
num_tokens[req_id] = req.num_computed_tokens + scheduler_output.num_scheduled_tokens[req_id]

# process scheduled_cached_reqs
cached_reqs = scheduler_output.scheduled_cached_reqs
for idx, req_id in enumerate(cached_reqs.req_ids):
num_tokens[req_id] = (
cached_reqs.num_computed_tokens[idx] +
scheduler_output.num_scheduled_tokens[req_id])
num_tokens[req_id] = cached_reqs.num_computed_tokens[idx] + scheduler_output.num_scheduled_tokens[req_id]

unallocated_req_ids = set(self.num_gpu_computed_tokens.keys() -
self.allocated_req_ids -
scheduler_output.num_scheduled_tokens.keys())
new_cpu_block_ids = self.zmq_rpc_client.call("allocate_slots",
num_tokens,
unallocated_req_ids)
unallocated_req_ids = set(
self.num_gpu_computed_tokens.keys() - self.allocated_req_ids - scheduler_output.num_scheduled_tokens.keys()
)
new_cpu_block_ids = self.zmq_rpc_client.call("allocate_slots", num_tokens, unallocated_req_ids)
metadata = CPUOffloadingConnectorMetadata(
requests={},
finished_req_ids=set(self.finished_req_ids),
@@ -222,22 +193,22 @@ class CPUOffloadingConnectorScheduler:
metadata.requests[req_id] = ReqMeta(
gpu_block_ids=[] if gpu_block_ids is None else gpu_block_ids,
cpu_block_ids=new_cpu_block_ids.get(req_id, []),
num_scheduled_tokens=scheduler_output.
num_scheduled_tokens[req_id],
num_scheduled_tokens=scheduler_output.num_scheduled_tokens[req_id],
num_computed_tokens=req.num_computed_tokens,
num_gpu_computed_tokens=self.num_gpu_computed_tokens[req_id],
num_cpu_computed_tokens=self.num_cpu_computed_tokens[req_id])
num_cpu_computed_tokens=self.num_cpu_computed_tokens[req_id],
)

for idx, req_id in enumerate(cached_reqs.req_ids):
gpu_block_ids = cached_reqs.new_block_ids[idx]
metadata.requests[req_id] = ReqMeta(
gpu_block_ids=[] if gpu_block_ids is None else gpu_block_ids,
cpu_block_ids=new_cpu_block_ids.get(req_id, []),
num_scheduled_tokens=scheduler_output.
num_scheduled_tokens[req_id],
num_scheduled_tokens=scheduler_output.num_scheduled_tokens[req_id],
num_computed_tokens=cached_reqs.num_computed_tokens[idx],
num_gpu_computed_tokens=cached_reqs.num_computed_tokens[idx],
num_cpu_computed_tokens=cached_reqs.num_computed_tokens[idx])
num_cpu_computed_tokens=cached_reqs.num_computed_tokens[idx],
)
self.num_gpu_computed_tokens.clear()
self.num_cpu_computed_tokens.clear()
self.allocated_req_ids.clear()
@@ -249,12 +220,10 @@ class CPUOffloadingConnectorScheduler:
request.get_hash_new_full_blocks = None
self.finished_req_ids.append(request.request_id)
# inform metadata server to record request, and free it after finish sending
self.zmq_rpc_client.call("record_request_cache_and_free_slots",
request)
self.zmq_rpc_client.call("record_request_cache_and_free_slots", request)


class CPUOffloadingConnectorWorker:

def __init__(self, vllm_config: VllmConfig):
logger.info("init CPUOffloadingConnectorWorker")
self.vllm_config = vllm_config
@@ -289,7 +258,7 @@ class CPUOffloadingConnectorWorker:
def init_metadata_server(self, vllm_config: VllmConfig):
self.metadata_thread = threading.Thread(
target=MetadataServerProc.run_metadata_server,
args=(vllm_config, ),
args=(vllm_config,),
)
self.metadata_thread.daemon = True
self.metadata_thread.start()
@@ -304,18 +273,15 @@ class CPUOffloadingConnectorWorker:
logger.info(f"wait for metadata server to start, error: {e}")
time.sleep(1)

def bind_connector_metadata(
self, connector_metadata: CPUOffloadingConnectorMetadata) -> None:
def bind_connector_metadata(self, connector_metadata: CPUOffloadingConnectorMetadata) -> None:
for req_id, req in connector_metadata.requests.items():
if req_id in self.requests:
self.requests[req_id].update(req)
req = self.requests[req_id]
else:
self.requests[req_id] = req
for i in range(req.num_gpu_computed_tokens // self.block_size,
req.num_computed_tokens // self.block_size):
self.load_block_mapping.append(
(req.cpu_block_ids[i], req.gpu_block_ids[i]))
for i in range(req.num_gpu_computed_tokens // self.block_size, req.num_computed_tokens // self.block_size):
self.load_block_mapping.append((req.cpu_block_ids[i], req.gpu_block_ids[i]))
for req_id in connector_metadata.finished_req_ids:
if req_id in self.requests:
self.save_input_queue.put((req_id, self.requests[req_id]))
@@ -326,11 +292,11 @@ class CPUOffloadingConnectorWorker:
def register_kv_caches(self, kv_caches: dict[str, Sequence[torch.Tensor]]):
self.gpu_kv_caches = kv_caches
model_config = self.vllm_config.model_config
mla_config: Optional[MLAConfig] = None
mla_config: MLAConfig | None = None
if model_config.use_mla:
mla_config = MLAConfig(
model_config.hf_text_config.kv_lora_rank,
model_config.hf_text_config.qk_rope_head_dim)
model_config.hf_text_config.kv_lora_rank, model_config.hf_text_config.qk_rope_head_dim
)
self.cpu_kv_caches = list(
self.zmq_rpc_client.call(
"init_cpu_kv_caches",
@@ -338,7 +304,8 @@ class CPUOffloadingConnectorWorker:
self.tp_rank,
get_kv_cache_spec(self.vllm_config),
mla_config,
).values())
).values()
)

def start_load_kv(self) -> None:
self.current_layer = 0
@@ -358,10 +325,8 @@ class CPUOffloadingConnectorWorker:
cpu_kv_caches = self.cpu_kv_caches[layer]
with torch.npu.stream(self.load_stream):
for cpu_block_id, gpu_block_id in self.load_block_mapping:
for gpu_layer_part, cpu_layer_part in zip(
gpu_kv_caches, cpu_kv_caches):
gpu_layer_part[gpu_block_id].copy_(
cpu_layer_part[cpu_block_id], non_blocking=True)
for gpu_layer_part, cpu_layer_part in zip(gpu_kv_caches, cpu_kv_caches):
gpu_layer_part[gpu_block_id].copy_(cpu_layer_part[cpu_block_id], non_blocking=True)

def get_finished(self) -> set[str]:
done_sending: set[str] = set()
@@ -380,8 +345,7 @@ class CPUOffloadingConnectorWorker:
self.done_sending_count[req_id] += 1
other_ranks_finished_ids: list[str] = []
for i in range(1, self.tp_world_size):
other_ranks_finished_ids.extend(
self.tp_group.recv_object(src=i))
other_ranks_finished_ids.extend(self.tp_group.recv_object(src=i))
for req_id in other_ranks_finished_ids:
self.done_sending_count[req_id] += 1
all_done_sending: set[str] = set()
@@ -391,8 +355,7 @@ class CPUOffloadingConnectorWorker:
all_done_sending.add(req_id)
# release cpu_kv_cache after request sending finished
# to avoid rpc blocking, use thread to call rpc asynchronously
sending_finished_thread = threading.Thread(
target=self._sending_finished, args=(all_done_sending, ))
sending_finished_thread = threading.Thread(target=self._sending_finished, args=(all_done_sending,))
sending_finished_thread.daemon = True
sending_finished_thread.start()

@@ -411,11 +374,10 @@ class CPUOffloadingConnectorWorker:
while True:
req_id, req = self.save_input_queue.get()
for i in range(
req.num_cpu_computed_tokens // self.block_size,
min((req.num_computed_tokens + req.num_scheduled_tokens) //
self.block_size, len(req.cpu_block_ids))):
save_block_mapping.append(
(req.gpu_block_ids[i], req.cpu_block_ids[i]))
req.num_cpu_computed_tokens // self.block_size,
min((req.num_computed_tokens + req.num_scheduled_tokens) // self.block_size, len(req.cpu_block_ids)),
):
save_block_mapping.append((req.gpu_block_ids[i], req.cpu_block_ids[i]))
with torch.npu.stream(self.save_stream):
# MLA: kv_layer is tuple[tensor, tensor] means (rope, nope).
# non-MLA: kv_layer is list[tensor], typically means [k, v].
@@ -425,13 +387,9 @@ class CPUOffloadingConnectorWorker:
start, step = 0, 1
for i in range(start, len(save_block_mapping), step):
gpu_block_id, cpu_block_id = save_block_mapping[i]
for cpu_kv_caches, gpu_kv_caches in zip(
self.cpu_kv_caches, self.gpu_kv_caches.values()):
for cpu_layer_part, gpu_layer_part in zip(
cpu_kv_caches, gpu_kv_caches):
cpu_layer_part[cpu_block_id].copy_(
gpu_layer_part[gpu_block_id],
non_blocking=True)
for cpu_kv_caches, gpu_kv_caches in zip(self.cpu_kv_caches, self.gpu_kv_caches.values()):
for cpu_layer_part, gpu_layer_part in zip(cpu_kv_caches, gpu_kv_caches):
cpu_layer_part[cpu_block_id].copy_(gpu_layer_part[gpu_block_id], non_blocking=True)
self.save_stream.synchronize()
self.save_output_queue.put(req_id)
save_block_mapping.clear()
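
Reviewer note: the save loop above is the usual async-copy-then-synchronize pattern: every block copy is queued non-blocking on a dedicated NPU stream, and a single synchronize at the end covers the whole batch. A stripped-down sketch (assuming torch_npu is installed, as in this file):

import torch

def copy_blocks(save_stream, mapping, gpu_cache, cpu_cache):
    # Queue all copies on the side stream, then wait once for the batch.
    with torch.npu.stream(save_stream):
        for gpu_block_id, cpu_block_id in mapping:
            cpu_cache[cpu_block_id].copy_(gpu_cache[gpu_block_id], non_blocking=True)
    save_stream.synchronize()
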
|
||||
@@ -453,8 +411,7 @@ def get_kv_cache_spec(vllm_config: VllmConfig) -> dict[str, KVCacheSpec]:
|
||||
if vllm_config.cache_config.cache_dtype == "auto":
|
||||
kv_cache_dtype = vllm_config.model_config.dtype
|
||||
else:
|
||||
kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
|
||||
vllm_config.cache_config.cache_dtype]
|
||||
kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[vllm_config.cache_config.cache_dtype]
|
||||
|
||||
kv_cache_spec: dict[str, KVCacheSpec] = {}
|
||||
attn_layers = get_layers_from_vllm_config(vllm_config, AttentionLayerBase)
|
||||
@@ -472,10 +429,8 @@ def get_kv_cache_spec(vllm_config: VllmConfig) -> dict[str, KVCacheSpec]:
|
||||
# using DSA. Fix the spec in vLLM is the final way.
|
||||
block_size = vllm_config.cache_config.block_size
|
||||
kv_cache_spec[layer_name] = FullAttentionSpec(
|
||||
block_size=block_size,
|
||||
num_kv_heads=1,
|
||||
head_size=attn_module.head_size,
|
||||
dtype=kv_cache_dtype)
|
||||
block_size=block_size, num_kv_heads=1, head_size=attn_module.head_size, dtype=kv_cache_dtype
|
||||
)
|
||||
elif spec := attn_module.get_kv_cache_spec(vllm_config):
|
||||
kv_cache_spec[layer_name] = spec
|
||||
|
||||
@@ -484,8 +439,7 @@ def get_kv_cache_spec(vllm_config: VllmConfig) -> dict[str, KVCacheSpec]:
|
||||
|
||||
if len(mamba_layers) > 0:
|
||||
if vllm_config.cache_config.enable_prefix_caching:
|
||||
raise NotImplementedError(
|
||||
"Prefix caching is not supported for Mamba yet.")
|
||||
raise NotImplementedError("Prefix caching is not supported for Mamba yet.")
|
||||
for layer_name, mamba_module in mamba_layers.items():
|
||||
if spec := mamba_module.get_kv_cache_spec(vllm_config):
|
||||
kv_cache_spec[layer_name] = spec
|
||||
|
||||
@@ -1,9 +1,10 @@
import math
import os
import pickle
from collections.abc import Callable
from dataclasses import dataclass
from multiprocessing.shared_memory import SharedMemory
from typing import Any, Callable, Optional
from typing import Any

import torch
import vllm.envs as envs
@@ -14,8 +15,7 @@ from vllm.utils.network_utils import make_zmq_socket
from vllm.utils.torch_utils import get_dtype_size
from vllm.v1.kv_cache_interface import AttentionSpec, MLAAttentionSpec

from vllm_ascend.distributed.kv_transfer.kv_pool.cpu_offload.cpu_kv_cache_manager import \
    CPUKVCacheManager
from vllm_ascend.distributed.kv_transfer.kv_pool.cpu_offload.cpu_kv_cache_manager import CPUKVCacheManager


@dataclass
@@ -30,8 +30,7 @@ def get_cpu_offload_connector(vllm_config: VllmConfig) -> KVTransferConfig:
    if kv_transfer_config.kv_connector == "CPUOffloadingConnector":
        return kv_transfer_config
    elif kv_transfer_config.kv_connector == "MultiConnector":
        ktcs = kv_transfer_config.kv_connector_extra_config.get(
            "connectors")
        ktcs = kv_transfer_config.kv_connector_extra_config.get("connectors")
        for ktc in ktcs:
            kv_transfer_config = KVTransferConfig(**ktc)
            if kv_transfer_config.kv_connector == "CPUOffloadingConnector":
@@ -44,7 +43,6 @@ class MetadataServer:
    DEFAULT_CPU_SWAP_SPACE_GB = 800


class ZMQRPCClient:

    def __init__(self, identity=None):
        if identity is None:
            identity = f"worker-{os.getpid()}-{id(self)}"
@@ -56,7 +54,8 @@ class MetadataServer:
            zmq.DEALER,  # type: ignore
            bind=False,
            identity=identity.encode(),
            linger=0)
            linger=0,
        )

    def call(self, func_name: str, *args, **kwargs) -> Any:
        request = (func_name, args, kwargs)
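        # Wire format (inferred from the surrounding code, illustrative): the
        # client pickles this (func_name, args, kwargs) tuple onto the DEALER
        # socket; the ROUTER side looks func_name up in self.functions and
        # pickles the return value back to this client's identity frame.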
@@ -74,11 +73,9 @@ class MetadataServer:
        self.shared_memory_dict = memory_dict
        result = {}
        for key, shm in memory_dict.items():
            tensor = torch.frombuffer(
                shm.buf, dtype=layer_dtype).reshape(layer_size)
            tensor = torch.frombuffer(shm.buf, dtype=layer_dtype).reshape(layer_size)
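            # Zero-copy mapping (illustrative sketch): torch.frombuffer() wraps
            # the SharedMemory buffer directly, so writes from other processes
            # become visible here without serialization, e.g.
            #   shm = SharedMemory(name="cpu_kv_cache_0_0_layer0", create=True, size=nbytes)
            #   t = torch.frombuffer(shm.buf, dtype=layer_dtype).reshape(layer_size)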
            if mla_config is not None:
                tensor = tensor.split(
                    [mla_config.nope_dim, mla_config.rope_dim], dim=-1)
                tensor = tensor.split([mla_config.nope_dim, mla_config.rope_dim], dim=-1)
            result[key] = tensor
        return result

@@ -86,7 +83,7 @@ class MetadataServer:
        # will be finalized by outer process
        self.socket.close()
        self.ctx.term()
        if hasattr(self, 'shared_memory_dict'):
        if hasattr(self, "shared_memory_dict"):
            for shm in self.shared_memory_dict.values():
                shm.close()

@@ -96,7 +93,8 @@ class MetadataServer:
        kv_transfer_config = get_cpu_offload_connector(vllm_config)
        assert kv_transfer_config is not None
        available_memory_gb = kv_transfer_config.get_from_extra_config(
            "cpu_swap_space_gb", MetadataServer.DEFAULT_CPU_SWAP_SPACE_GB)
            "cpu_swap_space_gb", MetadataServer.DEFAULT_CPU_SWAP_SPACE_GB
        )
        self.available_memory = available_memory_gb * 1024 * 1024 * 1024
        logger.info(f"cpu swap space: {self.available_memory} bytes")
        self.ctx = zmq.Context()  # type: ignore
@@ -105,7 +103,8 @@ class MetadataServer:
            MetadataServer.METADATA_SERVER_ADDRESS,
            zmq.ROUTER,  # type: ignore
            bind=True,
            linger=0)
            linger=0,
        )
        self.functions: dict[str, Callable] = {
            "init_cpu_kv_caches": self.init_cpu_kv_caches,
            "post_init": self.post_init,
@@ -133,15 +132,11 @@ class MetadataServer:
        tp_rank: int,
        kv_cache_specs: dict[str, AttentionSpec],
        mla_config: MLAConfig,
    ) -> tuple[dict[str, SharedMemory], tuple[int, ...], torch.dtype,
               MLAConfig]:
    ) -> tuple[dict[str, SharedMemory], tuple[int, ...], torch.dtype, MLAConfig]:
        logger.info(f"receive pp rank: {pp_rank}, tp rank: {tp_rank}")
        # follow the assumption that each layer has the same spec
        layer = next(iter(kv_cache_specs.values()))
        assert all([
            layer.page_size_bytes == any.page_size_bytes
            for any in kv_cache_specs.values()
        ])
        assert all([layer.page_size_bytes == any.page_size_bytes for any in kv_cache_specs.values()])
        use_mla = isinstance(layer, MLAAttentionSpec)
        # mla shares the same kv cache among different tp
        if use_mla:
@@ -154,30 +149,24 @@ class MetadataServer:
            available_memory //= self.pipeline_parallel_size
            available_memory //= len(kv_cache_specs)
            num_blocks = available_memory // layer.page_size_bytes
            layer_size = (num_blocks, layer.block_size, layer.num_kv_heads,
                          layer.head_size)  # type: ignore
            layer_size = (num_blocks, layer.block_size, layer.num_kv_heads, layer.head_size)  # type: ignore
        else:
            available_memory //= self.world_size
            available_memory //= len(kv_cache_specs)
            num_blocks = available_memory // layer.page_size_bytes
            layer_size = (2, num_blocks, layer.block_size, layer.num_kv_heads,
                          layer.head_size)  # type: ignore
            layer_size = (2, num_blocks, layer.block_size, layer.num_kv_heads, layer.head_size)  # type: ignore
        nbytes = math.prod(layer_size) * get_dtype_size(layer.dtype)
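        # Worked sizing example (illustrative numbers): with the default
        # 800 GiB swap space, a world_size of 8, 32 layers, and 2 MiB pages,
        # each rank gets 800 GiB / 8 / 32 = 3.125 GiB per layer, so
        # num_blocks = 3.125 GiB // 2 MiB = 1600 CPU blocks.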
        for layer_name in kv_cache_specs.keys():
        for layer_name in kv_cache_specs:
            # only this format can share during ZeroMQ+pickle
            shared_memory_dict[
                layer_name] = MetadataServer._safe_create_shared_memory(
                    f"cpu_kv_cache_{pp_rank}_{tp_rank}_{layer_name}", nbytes)
            shared_memory_dict[layer_name] = MetadataServer._safe_create_shared_memory(
                f"cpu_kv_cache_{pp_rank}_{tp_rank}_{layer_name}", nbytes
            )
        if use_mla:
            assert mla_config is not None
            assert layer.head_size == mla_config.rope_dim + mla_config.nope_dim
            self.shared_memory[(pp_rank,
                                tp_rank)] = (shared_memory_dict, layer_size,
                                             layer.dtype, mla_config)
            self.shared_memory[(pp_rank, tp_rank)] = (shared_memory_dict, layer_size, layer.dtype, mla_config)
        else:
            self.shared_memory[(pp_rank,
                                tp_rank)] = (shared_memory_dict, layer_size,
                                             layer.dtype, None)
            self.shared_memory[(pp_rank, tp_rank)] = (shared_memory_dict, layer_size, layer.dtype, None)
        if self.num_cpu_blocks == -1 or num_blocks < self.num_cpu_blocks:
            self.num_cpu_blocks = num_blocks
            self.layer = layer
@@ -185,23 +174,20 @@ class MetadataServer:

    def post_init(self):
        # different processors in data parallel may call multiple times
        if hasattr(self, 'cpu_block_manager'):
        if hasattr(self, "cpu_block_manager"):
            return
        # do shared_memory() at least once
        logger.info(f"assign cpu num blocks: {self.num_cpu_blocks}")
        assert self.num_cpu_blocks >= 0
        self.cpu_block_manager = CPUKVCacheManager(self.layer,
                                                   self.num_cpu_blocks)
        self.functions.update({
            "get_matched_num_and_touch":
            self.cpu_block_manager.get_matched_num_and_touch,
            "allocate_slots":
            self.cpu_block_manager.allocate_slots,
            "record_request_cache_and_free_slots":
            self.cpu_block_manager.record_request_cache_and_free_slots,
            "cache_and_free_slots":
            self.cpu_block_manager.cache_and_free_slots,
        })
        self.cpu_block_manager = CPUKVCacheManager(self.layer, self.num_cpu_blocks)
        self.functions.update(
            {
                "get_matched_num_and_touch": self.cpu_block_manager.get_matched_num_and_touch,
                "allocate_slots": self.cpu_block_manager.allocate_slots,
                "record_request_cache_and_free_slots": self.cpu_block_manager.record_request_cache_and_free_slots,
                "cache_and_free_slots": self.cpu_block_manager.cache_and_free_slots,
            }
        )

    def serve_step(self):
        client_id = self.socket.recv()
@@ -228,8 +214,7 @@ class MetadataServer:
    def shutdown(self):
        self.socket.close()
        self.ctx.term()
        socket_path = MetadataServer.METADATA_SERVER_ADDRESS.replace(
            "ipc://", "")
        socket_path = MetadataServer.METADATA_SERVER_ADDRESS.replace("ipc://", "")
        if os.path.exists(socket_path):
            os.remove(socket_path)
        for cached in self.shared_memory.values():
@@ -239,11 +224,9 @@ class MetadataServer:


class MetadataServerProc:

    @staticmethod
    def run_metadata_server(vllm_config: VllmConfig):
        if (not vllm_config.cache_config.enable_prefix_caching
                or get_cpu_offload_connector(vllm_config) is None):
        if not vllm_config.cache_config.enable_prefix_caching or get_cpu_offload_connector(vllm_config) is None:
            return

        shutdown_requested = False
@@ -257,7 +240,7 @@ class MetadataServerProc:
        # Either SIGTERM or SIGINT will terminate the worker
        # signal.signal(signal.SIGTERM, _signal_handler)
        # signal.signal(signal.SIGINT, _signal_handler)
        metadata_server: Optional[MetadataServer] = None
        metadata_server: MetadataServer | None = None
        try:
            metadata_server = MetadataServer(vllm_config)
            logger.info("Metadata server started.")

@@ -4,19 +4,21 @@ from typing import TYPE_CHECKING, Any, Optional
import torch
from ucm.integration.vllm.ucm_connector import UCMConnector
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
    KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole
from vllm.logger import init_logger
from vllm.v1.core.sched.output import SchedulerOutput


logger = init_logger(__name__)

# isort: off
if TYPE_CHECKING:
    from vllm.v1.attention.backend import AttentionMetadata  # type: ignore
    from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
        KVConnectorPromMetrics, KVConnectorStats, PromMetric, PromMetricT)
        KVConnectorPromMetrics,
        KVConnectorStats,
        PromMetric,
        PromMetricT,
    )
    from vllm.forward_context import ForwardContext
    from vllm.v1.core.kv_cache_manager import KVCacheBlocks
    from vllm.v1.kv_cache_interface import KVCacheConfig
@@ -25,16 +27,13 @@ if TYPE_CHECKING:


class UCMConnectorV1(KVConnectorBase_V1):

    def __init__(
        self,
        vllm_config: "VllmConfig",
        role: KVConnectorRole,
        kv_cache_config: "KVCacheConfig",
    ):
        super().__init__(vllm_config=vllm_config,
                         role=role,
                         kv_cache_config=kv_cache_config)
        super().__init__(vllm_config=vllm_config, role=role, kv_cache_config=kv_cache_config)
        assert vllm_config.kv_transfer_config is not None

        ImplCls = UCMConnector
@@ -60,8 +59,7 @@ class UCMConnectorV1(KVConnectorBase_V1):
        """
        self._ucm_engine.register_kv_caches(kv_caches)

    def start_load_kv(self, forward_context: "ForwardContext",
                      **kwargs: Any) -> None:
    def start_load_kv(self, forward_context: "ForwardContext", **kwargs: Any) -> None:
        """
        Start loading the KV cache from the connector to vLLM's paged
        KV buffer. This is called from the forward context before the
@@ -110,8 +108,7 @@ class UCMConnectorV1(KVConnectorBase_V1):
            attn_metadata (AttentionMetadata): the attention metadata.
            **kwargs: additional arguments for the save operation.
        """
        self._ucm_engine.save_kv_layer(layer_name, kv_layer, attn_metadata,
                                       **kwargs)
        self._ucm_engine.save_kv_layer(layer_name, kv_layer, attn_metadata, **kwargs)

    def wait_for_save(self) -> None:
        """
@@ -131,8 +128,7 @@ class UCMConnectorV1(KVConnectorBase_V1):
        """
        self._ucm_engine.clear_connector_metadata()

    def bind_connector_metadata(
            self, connector_metadata: KVConnectorMetadata) -> None:
    def bind_connector_metadata(self, connector_metadata: KVConnectorMetadata) -> None:
        """Set the connector metadata from the scheduler.

        This function should be called by the model runner every time
@@ -175,20 +171,15 @@ class UCMConnectorV1(KVConnectorBase_V1):
            the number of tokens that can be loaded from the
            external KV cache beyond what is already computed.
        """
        return self._ucm_engine.get_num_new_matched_tokens(
            request, num_computed_tokens)
        return self._ucm_engine.get_num_new_matched_tokens(request, num_computed_tokens)

    def update_state_after_alloc(self, request: "Request",
                                 blocks: "KVCacheBlocks",
                                 num_external_tokens: int) -> None:
    def update_state_after_alloc(self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int) -> None:
        """
        Update KVConnector state after block allocation.
        """
        self._ucm_engine.update_state_after_alloc(request, blocks,
                                                  num_external_tokens)
        self._ucm_engine.update_state_after_alloc(request, blocks, num_external_tokens)

    def build_connector_meta(
            self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
    def build_connector_meta(self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
        """
        Build the connector metadata for this step.

@@ -222,10 +213,7 @@ class UCMConnectorV1(KVConnectorBase_V1):
    # ==============================

    @classmethod
    def build_kv_connector_stats(
        cls,
        data: dict[str, Any] | None = None
    ) -> Optional["KVConnectorStats"]:
    def build_kv_connector_stats(cls, data: dict[str, Any] | None = None) -> Optional["KVConnectorStats"]:
        """
        KVConnectorStats resolution method. This method allows dynamically
        registered connectors to return their own KVConnectorStats object,

@@ -1,19 +1,16 @@
import ipaddress
import threading
from typing import Optional

from mooncake.engine import TransferEngine  # type: ignore


class GlobalTE():

class GlobalTE:
    def __init__(self):
        self.transfer_engine = None
        self.is_register_buffer: bool = False
        self.transfer_engine_lock = threading.Lock()
        self.register_buffer_lock = threading.Lock()

    def get_transfer_engine(self, hostname: str, device_name: Optional[str]):
    def get_transfer_engine(self, hostname: str, device_name: str | None):
        if self.transfer_engine is None:
            with self.transfer_engine_lock:
                # Double-Checked Locking
@@ -22,12 +19,9 @@ class GlobalTE():
                    raise RuntimeError("mooncake is not available")
                self.transfer_engine = TransferEngine()
                device_name = device_name if device_name is not None else ""
                ret_value = self.transfer_engine.initialize(
                    hostname, "P2PHANDSHAKE", "ascend", device_name)
                ret_value = self.transfer_engine.initialize(hostname, "P2PHANDSHAKE", "ascend", device_name)
                if ret_value != 0:
                    raise RuntimeError(
                        f"TransferEngine initialization failed with ret_value: {ret_value}"
                    )
                    raise RuntimeError(f"TransferEngine initialization failed with ret_value: {ret_value}")
        return self.transfer_engine
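        # Double-checked locking, as noted above: the unsynchronized None check
        # keeps the hot path lock-free, and the re-check under the lock stops
        # two racing threads from both initializing. Minimal sketch:
        #   if self.engine is None:            # fast path, no lock taken
        #       with self.lock:
        #           if self.engine is None:    # re-check under the lock
        #               self.engine = build()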

    def register_buffer(self, ptrs: list[int], sizes: list[int]):

@@ -6,8 +6,7 @@ import torch.distributed as dist
from vllm_ascend.distributed.parallel_state import get_p_tp_group


def kv_alltoall_and_rearrange(pd_tp_ratio: int, key: torch.Tensor,
                              value: torch.TensorType):
def kv_alltoall_and_rearrange(pd_tp_ratio: int, key: torch.Tensor, value: torch.TensorType):
    if pd_tp_ratio <= 1:
        return None, None
    elif key is None or value is None:
@@ -20,22 +19,17 @@ def kv_alltoall_and_rearrange(pd_tp_ratio: int, key: torch.Tensor,
    def alltoall_and_rearrange(tp_ratio: int, input_tensor: torch.Tensor):
        num_kv_heads = input_tensor.size(1)
        output_tensor = torch.zeros_like(input_tensor)
        dist.all_to_all_single(output_tensor,
                               input_tensor,
                               group=get_p_tp_group().device_group)
        dist.all_to_all_single(output_tensor, input_tensor, group=get_p_tp_group().device_group)
        input_tensor = 0
        result = rearrange_output(output_tensor, tp_ratio, num_kv_heads)
        output_tensor = 0
        return result


def rearrange_output(base_output: torch.Tensor, cut_num: int,
                     num_kv_heads: int):
def rearrange_output(base_output: torch.Tensor, cut_num: int, num_kv_heads: int):
    size_0 = base_output.size(0)
    if size_0 % cut_num != 0:
        raise ValueError(
            f"The size of dim 0 [{size_0}] must be divisible by the cut_num [{cut_num}]"
        )
        raise ValueError(f"The size of dim 0 [{size_0}] must be divisible by the cut_num [{cut_num}]")
    chunk_size = size_0 // cut_num
    reshaped = base_output.view(cut_num, chunk_size, -1)
    transposed = reshaped.transpose(0, 1)
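    # Worked example (illustrative): for cut_num=2 and rows [A0, A1, B0, B1],
    # view(2, 2, -1) groups them as [[A0, A1], [B0, B1]] and transpose(0, 1)
    # interleaves the chunks into [[A0, B0], [A1, B1]], so the flattened row
    # order becomes A0, B0, A1, B1, regrouping KV heads after the all-to-all.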
@@ -46,16 +40,13 @@ def align_memory(tensor: torch.Tensor, alignment: int) -> torch.Tensor:
    data_ptr = tensor.data_ptr()
    aligned_addr = (data_ptr + alignment - 1) // alignment * alignment
    offset = (aligned_addr - data_ptr) // tensor.element_size()
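    # Worked example (illustrative): for fp16 (element_size=2), data_ptr()=4098
    # and alignment=512 give aligned_addr=4608, so offset=(4608-4098)//2=255
    # elements and the returned tensor[255:] starts on the 512-byte boundary.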
    return tensor[int(offset):]
    return tensor[int(offset) :]


def get_transfer_timeout_value():
    ascend_transfer_timeout = os.getenv("ASCEND_TRANSFER_TIMEOUT", "")
    if len(ascend_transfer_timeout) > 0:
        return int(ascend_transfer_timeout)
    hccl_rdma_timeout = int(os.getenv('HCCL_RDMA_TIMEOUT',
                                      '20'))  # type: ignore
    hccl_rdma_retry_cnt = int(os.getenv('HCCL_RDMA_RETRY_CNT',
                                        '7'))  # type: ignore
    return int((4.096 * (2**hccl_rdma_timeout)) * hccl_rdma_retry_cnt // 1000 +
               3000)
    hccl_rdma_timeout = int(os.getenv("HCCL_RDMA_TIMEOUT", "20"))  # type: ignore
    hccl_rdma_retry_cnt = int(os.getenv("HCCL_RDMA_RETRY_CNT", "7"))  # type: ignore
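    # With the defaults (HCCL_RDMA_TIMEOUT=20, HCCL_RDMA_RETRY_CNT=7), the
    # fallback below evaluates to int(4.096 * 2**20 * 7 // 1000 + 3000) =
    # 33064 ms: the worst-case RDMA retransmission window plus a 3 s margin.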
    return int((4.096 * (2**hccl_rdma_timeout)) * hccl_rdma_retry_cnt // 1000 + 3000)

@@ -4,8 +4,7 @@ from vllm.logger import init_logger
from vllm.utils.platform_utils import is_pin_memory_available
from vllm.v1.attention.backend import AttentionBackend  # type: ignore
from vllm.v1.kv_offload.mediums import CPULoadStoreSpec, GPULoadStoreSpec
from vllm.v1.kv_offload.worker.worker import (OffloadingHandler,
                                              TransferResult, TransferSpec)
from vllm.v1.kv_offload.worker.worker import OffloadingHandler, TransferResult, TransferSpec

logger = init_logger(__name__)

@@ -44,7 +43,6 @@ def expand_block_ids(


class CpuNpuOffloadingHandler(OffloadingHandler):

    def __init__(
        self,
        gpu_block_size: int,
@@ -81,20 +79,22 @@ class CpuNpuOffloadingHandler(OffloadingHandler):
            cpu_shape[num_blocks_idx] = num_cpu_blocks * self.block_size_factor

            logger.debug("Allocating CPU tensor of shape %r", cpu_shape)
            self.cpu_tensors.append((
                torch.zeros(
                    cpu_shape,
                    dtype=gpu_tensor[0].dtype,
                    device="cpu",
                    pin_memory=pin_memory,
                ),
                torch.zeros(
                    cpu_shape,
                    dtype=gpu_tensor[0].dtype,
                    device="cpu",
                    pin_memory=pin_memory,
                ),
            ))
            self.cpu_tensors.append(
                (
                    torch.zeros(
                        cpu_shape,
                        dtype=gpu_tensor[0].dtype,
                        device="cpu",
                        pin_memory=pin_memory,
                    ),
                    torch.zeros(
                        cpu_shape,
                        dtype=gpu_tensor[0].dtype,
                        device="cpu",
                        pin_memory=pin_memory,
                    ),
                )
            )

    def transfer_async(self, job_id: int, spec: TransferSpec) -> bool:
        logger.info("start transfer_async...")
@@ -123,9 +123,7 @@ class CpuNpuOffloadingHandler(OffloadingHandler):
        dst_sub_blocks_to_skip = -src_blocks.size % dst_block_size_factor
        src_sub_block_count = src_blocks.size * src_block_size_factor

        assert (
            src_sub_block_count == dst_blocks.size * dst_block_size_factor -
            dst_sub_blocks_to_skip)
        assert src_sub_block_count == dst_blocks.size * dst_block_size_factor - dst_sub_blocks_to_skip

        src_to_dst = np.empty((src_sub_block_count, 2), dtype=np.int64)
        expand_block_ids(src_blocks, src_block_size_factor, src_to_dst[:, 0])
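        # Worked example (assumed block sizes, illustrative): offloading 3 GPU
        # blocks of 128 tokens to CPU blocks of 256 tokens gives a
        # dst_block_size_factor of 2, so the store fills 2 CPU blocks
        # (4 sub-blocks) and dst_sub_blocks_to_skip = -3 % 2 = 1 accounts for
        # the unused trailing sub-block in the assert above.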
@@ -137,18 +135,14 @@ class CpuNpuOffloadingHandler(OffloadingHandler):
        )
        src_to_dst_tensor = torch.from_numpy(src_to_dst)

        event = self.events_pool.pop(
        ) if self.events_pool else torch.npu.Event()
        event = self.events_pool.pop() if self.events_pool else torch.npu.Event()
        with torch.npu.stream(stream):
            for src_tensor, dst_tensor in zip(src_tensors, dst_tensors):
                src_key_cache, src_value_cache = src_tensor[0], src_tensor[1]
                dst_key_cache, dst_value_cache = dst_tensor[0], dst_tensor[1]

                torch.ops._C_ascend.swap_blocks(src_key_cache, dst_key_cache,
                                                src_to_dst_tensor)
                torch.ops._C_ascend.swap_blocks(src_value_cache,
                                                dst_value_cache,
                                                src_to_dst_tensor)
                torch.ops._C_ascend.swap_blocks(src_key_cache, dst_key_cache, src_to_dst_tensor)
                torch.ops._C_ascend.swap_blocks(src_value_cache, dst_value_cache, src_to_dst_tensor)

            event.record(stream)

@@ -175,4 +169,4 @@ class CpuNpuOffloadingHandler(OffloadingHandler):
        event = self.transfer_events.get(job_id)
        if event is not None:
            # This will block until the NPU event is complete
            event.synchronize()
            event.synchronize()

@@ -1,48 +1,40 @@
from collections.abc import Iterator
from typing import Optional

import torch
from vllm.config import VllmConfig
from vllm.v1.attention.backend import AttentionBackend  # type: ignore
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.kv_offload.abstract import LoadStoreSpec, OffloadingManager
from vllm.v1.kv_offload.backends.cpu import CPUBackend
from vllm.v1.kv_offload.lru_manager import LRUOffloadingManager
from vllm.v1.kv_offload.mediums import CPULoadStoreSpec, GPULoadStoreSpec
from vllm.v1.kv_offload.spec import OffloadingSpec
from vllm.v1.kv_offload.worker.worker import OffloadingHandler
from vllm.v1.kv_cache_interface import KVCacheConfig

from vllm_ascend.kv_offload.cpu_npu import CpuNpuOffloadingHandler


class NPUOffloadingSpec(OffloadingSpec):

    def __init__(self,
                 vllm_config: VllmConfig,
                 kv_cache_config: Optional[KVCacheConfig] = None):
    def __init__(self, vllm_config: VllmConfig, kv_cache_config: KVCacheConfig | None = None):
        super().__init__(vllm_config, kv_cache_config)

        num_cpu_blocks = self.extra_config.get("num_cpu_blocks")
        if not num_cpu_blocks:
            raise Exception(
                "num_cpu_blocks must be specified in kv_connector_extra_config"
            )
            raise Exception("num_cpu_blocks must be specified in kv_connector_extra_config")
        self.num_cpu_blocks: int = num_cpu_blocks

        # scheduler-side
        self._manager: Optional[OffloadingManager] = None
        self._manager: OffloadingManager | None = None

        # worker-side
        self._handler: Optional[OffloadingHandler] = None
        self._handler: OffloadingHandler | None = None

    def get_manager(self) -> OffloadingManager:
        if not self._manager:
            kv_events_config = self.vllm_config.kv_events_config
            enable_events = (kv_events_config is not None
                             and kv_events_config.enable_kv_cache_events)
            enable_events = kv_events_config is not None and kv_events_config.enable_kv_cache_events
            self._manager = LRUOffloadingManager(
                CPUBackend(block_size=self.offloaded_block_size,
                           num_blocks=self.num_cpu_blocks),
                CPUBackend(block_size=self.offloaded_block_size, num_blocks=self.num_cpu_blocks),
                enable_events=enable_events,
            )
        return self._manager
@@ -51,8 +43,7 @@ class NPUOffloadingSpec(OffloadingSpec):
        self,
        kv_caches: dict[str, torch.Tensor],
        attn_backends: dict[str, type[AttentionBackend]],
    ) -> Iterator[tuple[type[LoadStoreSpec], type[LoadStoreSpec],
                        OffloadingHandler]]:
    ) -> Iterator[tuple[type[LoadStoreSpec], type[LoadStoreSpec], OffloadingHandler]]:
        if not self._handler:
            self._handler = CpuNpuOffloadingHandler(
                attn_backends=attn_backends,

@@ -16,11 +16,13 @@
import torch


def bgmv_shrink(inputs: torch.Tensor,
                lora_a_weights: torch.Tensor,
                output_tensor: torch.Tensor,
                lora_indices_tensor: torch.Tensor,
                scaling: float = 1.0):
def bgmv_shrink(
    inputs: torch.Tensor,
    lora_a_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    scaling: float = 1.0,
):
    return torch.ops._C_ascend.bgmv_shrink(
        inputs,
        lora_a_weights,
@@ -30,11 +32,13 @@ def bgmv_shrink(inputs: torch.Tensor,
    )


def bgmv_expand(inputs: torch.Tensor,
                lora_b_weights: torch.Tensor,
                output_tensor: torch.Tensor,
                lora_indices_tensor: torch.Tensor,
                add_inputs: bool = True):
def bgmv_expand(
    inputs: torch.Tensor,
    lora_b_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    add_inputs: bool = True,
):
    return torch.ops._C_ascend.bgmv_expand(
        inputs,
        lora_b_weights,
@@ -45,16 +49,18 @@ def bgmv_expand(inputs: torch.Tensor,
    )


def bgmv_expand_slice(inputs: torch.Tensor,
                      lora_b_weights: torch.Tensor,
                      output_tensor: torch.Tensor,
                      lora_indices_tensor: torch.Tensor,
                      slice_offset: int,
                      slice_size: int,
                      add_inputs: bool = True):
    return torch.ops._C_ascend.bgmv_expand(inputs, lora_b_weights,
                                           lora_indices_tensor, output_tensor,
                                           slice_offset, slice_size)
def bgmv_expand_slice(
    inputs: torch.Tensor,
    lora_b_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    slice_offset: int,
    slice_size: int,
    add_inputs: bool = True,
):
    return torch.ops._C_ascend.bgmv_expand(
        inputs, lora_b_weights, lora_indices_tensor, output_tensor, slice_offset, slice_size
    )


def sgmv_shrink(
@@ -69,21 +75,23 @@ def sgmv_shrink(
    token_nums: int,
    scaling: float,
):
    return torch.ops._C_ascend.sgmv_shrink(inputs, lora_a_weights,
                                           lora_indices_tensor, seq_len_tensor,
                                           output_tensor, scaling)
    return torch.ops._C_ascend.sgmv_shrink(
        inputs, lora_a_weights, lora_indices_tensor, seq_len_tensor, output_tensor, scaling
    )


def sgmv_expand(inputs: torch.Tensor,
                lora_b_weights: torch.Tensor,
                output_tensor: torch.Tensor,
                b_seq_start_loc: torch.Tensor,
                seq_len_tensor: torch.Tensor,
                lora_indices_tensor: torch.Tensor,
                batches: int,
                max_seq_length: int,
                token_nums: int,
                add_inputs: bool = False):
def sgmv_expand(
    inputs: torch.Tensor,
    lora_b_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    b_seq_start_loc: torch.Tensor,
    seq_len_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    batches: int,
    max_seq_length: int,
    token_nums: int,
    add_inputs: bool = False,
):
    return torch.ops._C_ascend.sgmv_expand(
        inputs,
        lora_b_weights,
@@ -95,19 +103,20 @@ def sgmv_expand(inputs: torch.Tensor,
    )


def sgmv_expand_slice(inputs: torch.Tensor,
                      lora_b_weights: torch.Tensor,
                      output_tensor: torch.Tensor,
                      b_seq_start_loc: torch.Tensor,
                      seq_len_tensor: torch.Tensor,
                      lora_indices_tensor: torch.Tensor,
                      batches: int,
                      max_seq_length: int,
                      token_nums: int,
                      slice_offset: int,
                      slice_size: int,
                      add_inputs: bool = False):
    return torch.ops._C_ascend.sgmv_expand(inputs, lora_b_weights,
                                           lora_indices_tensor, seq_len_tensor,
                                           output_tensor, slice_offset,
                                           slice_size)
def sgmv_expand_slice(
    inputs: torch.Tensor,
    lora_b_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    b_seq_start_loc: torch.Tensor,
    seq_len_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    batches: int,
    max_seq_length: int,
    token_nums: int,
    slice_offset: int,
    slice_size: int,
    add_inputs: bool = False,
):
    return torch.ops._C_ascend.sgmv_expand(
        inputs, lora_b_weights, lora_indices_tensor, seq_len_tensor, output_tensor, slice_offset, slice_size
    )

@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Callable, Optional, Tuple, Union
from collections.abc import Callable

import torch
from vllm.lora.punica_wrapper.punica_base import PunicaWrapperBase
@@ -18,26 +18,30 @@ class PunicaWrapperNPU(PunicaWrapperBase):
    Multi-LoRA, and to provide the interface for the pytorch punica ops.
    """

    def __init__(self, max_num_batched_tokens: int, max_batches: int,
                 device: Union[torch.device, str], **kwargs):
        PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches,
                                   device)
    def __init__(self, max_num_batched_tokens: int, max_batches: int, device: torch.device | str, **kwargs):
        PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, device)
        refresh_all_lora_classes()
        self.lora_config = kwargs.get("lora_config")
        if get_ascend_device_type() == AscendDeviceType._310P or (
                self.lora_config is not None
                and self.lora_config.max_lora_rank >= 128):
            from vllm.lora.ops.torch_ops import (bgmv_expand,
                                                 bgmv_expand_slice,
                                                 bgmv_shrink, sgmv_expand,
                                                 sgmv_expand_slice,
                                                 sgmv_shrink)
            self.lora_config is not None and self.lora_config.max_lora_rank >= 128
        ):
            from vllm.lora.ops.torch_ops import (
                bgmv_expand,
                bgmv_expand_slice,
                bgmv_shrink,
                sgmv_expand,
                sgmv_expand_slice,
                sgmv_shrink,
            )
        else:
            from vllm_ascend.lora.lora_ops import (bgmv_expand,
                                                   bgmv_expand_slice,
                                                   bgmv_shrink, sgmv_expand,
                                                   sgmv_expand_slice,
                                                   sgmv_shrink)
            from vllm_ascend.lora.lora_ops import (
                bgmv_expand,
                bgmv_expand_slice,
                bgmv_shrink,
                sgmv_expand,
                sgmv_expand_slice,
                sgmv_shrink,
            )
        self.bgmv_expand = bgmv_expand
        self.bgmv_expand_slice = bgmv_expand_slice
        self.bgmv_shrink = bgmv_shrink
@@ -52,7 +56,7 @@ class PunicaWrapperNPU(PunicaWrapperBase):
        w_t_all: torch.Tensor,
        scale: float,
    ):
        #No LoRA request, so return directly
        # No LoRA request, so return directly
        if self.no_lora:
            return
        self.sgmv_shrink(
@@ -79,7 +83,7 @@ class PunicaWrapperNPU(PunicaWrapperBase):
        w_t_all: torch.Tensor,
        add_inputs: bool,
    ):
        #No LoRA request, so return directly
        # No LoRA request, so return directly
        if self.no_lora:
            return
        self.sgmv_expand(
@@ -108,7 +112,7 @@ class PunicaWrapperNPU(PunicaWrapperBase):
        y_slice_size: int,
        add_inputs: bool,
    ):
        #No LoRA request, so return directly
        # No LoRA request, so return directly
        if self.no_lora:
            return
        self.sgmv_expand_slice(
@@ -130,8 +134,7 @@ class PunicaWrapperNPU(PunicaWrapperBase):
        y_slice_size: int,
        add_inputs: bool,
    ):
        self.bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices,
                               y_offset, y_slice_size, add_inputs)
        self.bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset, y_slice_size, add_inputs)

    def _apply_expand(
        self,
@@ -148,13 +151,10 @@ class PunicaWrapperNPU(PunicaWrapperBase):
        GEMM of lora_b.
        """

        expand_slice_fun: Callable = (self._expand_slice_prefill
                                      if self.is_prefill else
                                      self._expand_slice_decode)
        expand_slice_fun: Callable = self._expand_slice_prefill if self.is_prefill else self._expand_slice_decode
        expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_inputs)

    def _apply_shrink(self, y: torch.Tensor, x: torch.Tensor,
                      w_t_all: torch.Tensor, scale: float):
    def _apply_shrink(self, y: torch.Tensor, x: torch.Tensor, w_t_all: torch.Tensor, scale: float):
        """
        Perform the `y += x @ w_t_all` computation, which is suitable for the
        GEMM of lora_a.
@@ -165,14 +165,18 @@ class PunicaWrapperNPU(PunicaWrapperBase):
        """
        y_org = y
        y = y.view(-1, y.shape[-1])
        shrink_fun: Callable = (self._shrink_prefill
                                if self.is_prefill else self._shrink_decode)
        shrink_fun: Callable = self._shrink_prefill if self.is_prefill else self._shrink_decode
        shrink_fun(y, x, w_t_all, scale)
        y = y.view_as(y_org)

    def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor],
                   x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...],
                   scale: float, **kwargs):
    def add_shrink(
        self,
        y: tuple[torch.Tensor, ...] | torch.Tensor,
        x: torch.Tensor,
        lora_a_stacked: tuple[torch.Tensor, ...],
        scale: float,
        **kwargs,
    ):
        """
        Performs GEMM for multiple slices of lora_a.
        When `is_prefill` is true, it indicates that it is currently the
@@ -194,18 +198,19 @@ class PunicaWrapperNPU(PunicaWrapperBase):
        x = x.view(-1, x.shape[-1])
        # TODO fuse these kernels
        for slice_idx in range(len(lora_a_stacked)):
            self._apply_shrink(y[slice_idx], x, lora_a_stacked[slice_idx],
                               scale)
            self._apply_shrink(y[slice_idx], x, lora_a_stacked[slice_idx], scale)

    def add_expand(self,
                   y: torch.Tensor,
                   x: Union[Tuple[torch.Tensor, ...], torch.Tensor],
                   lora_b_stacked: Tuple[torch.Tensor, ...],
                   lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
                   output_slices: Tuple[int, ...],
                   offset_start: int = 0,
                   add_inputs=True,
                   **kwargs) -> None:
    def add_expand(
        self,
        y: torch.Tensor,
        x: tuple[torch.Tensor, ...] | torch.Tensor,
        lora_b_stacked: tuple[torch.Tensor, ...],
        lora_bias_stacked: tuple[torch.Tensor, ...] | None,
        output_slices: tuple[int, ...],
        offset_start: int = 0,
        add_inputs=True,
        **kwargs,
    ) -> None:
        """
        Performs GEMM and bias addition for multiple slices of lora_b.

@@ -229,8 +234,7 @@ class PunicaWrapperNPU(PunicaWrapperBase):
        y = y.view(-1, y.shape[-1])
        offset_left = offset_start
        if lora_bias_stacked is not None:
            self._apply_bias(self.token_lora_indices, y, output_slices,
                             lora_bias_stacked)
            self._apply_bias(self.token_lora_indices, y, output_slices, lora_bias_stacked)
        for slice_idx in range(len(lora_b_stacked)):
            self._apply_expand(
                y,
@@ -243,12 +247,9 @@ class PunicaWrapperNPU(PunicaWrapperBase):
            offset_left += output_slices[slice_idx]
        y = y.view_as(y_org)

    def add_lora_embedding(self,
                           y: torch.Tensor,
                           x: torch.Tensor,
                           lora_b_stacked: torch.Tensor,
                           add_inputs: bool = True,
                           **kwargs) -> None:
    def add_lora_embedding(
        self, y: torch.Tensor, x: torch.Tensor, lora_b_stacked: torch.Tensor, add_inputs: bool = True, **kwargs
    ) -> None:
        """
        Applies lora specifically for VocabParallelEmbeddingWithLoRA.

@@ -263,21 +264,22 @@ class PunicaWrapperNPU(PunicaWrapperBase):
        """

        # Embedding layer only needs the expand op
        expand_fun: Callable = (self._expand_prefill
                                if self.is_prefill else self._expand_decode)
        expand_fun: Callable = self._expand_prefill if self.is_prefill else self._expand_decode
        x = x.to(torch.float32)
        expand_fun(y, x, lora_b_stacked, add_inputs)

    def add_lora_linear(self,
                        y: torch.Tensor,
                        x: torch.Tensor,
                        lora_a_stacked: Tuple[torch.Tensor, ...],
                        lora_b_stacked: Tuple[torch.Tensor, ...],
                        scale: float,
                        output_slices: Tuple[int, ...],
                        *,
                        buffer: Optional[Tuple[torch.Tensor, ...]] = None,
                        **kwargs) -> None:
    def add_lora_linear(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        lora_a_stacked: tuple[torch.Tensor, ...],
        lora_b_stacked: tuple[torch.Tensor, ...],
        scale: float,
        output_slices: tuple[int, ...],
        *,
        buffer: tuple[torch.Tensor, ...] | None = None,
        **kwargs,
    ) -> None:
        """
        Applicable to linear-related lora.

@@ -308,27 +310,22 @@ class PunicaWrapperNPU(PunicaWrapperBase):
            # We set the buffer to be float32 by default, consistent with the
            # triton op
            buffer = tuple(
                torch.zeros(
                    (x.size(0), r), dtype=torch.float32, device=x.device)
                for _ in range(len(output_slices)))
                torch.zeros((x.size(0), r), dtype=torch.float32, device=x.device) for _ in range(len(output_slices))
            )
        self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs)
        self.add_expand(y,
                        buffer,
                        lora_b_stacked,
                        None,
                        output_slices,
                        add_inputs=True,
                        **kwargs)
        self.add_expand(y, buffer, lora_b_stacked, None, output_slices, add_inputs=True, **kwargs)
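        # The shrink/expand pair above is the standard low-rank LoRA update; a
        # minimal dense sketch of what the fused kernels compute (weight
        # layouts illustrative):
        #   buffer = scale * (x @ lora_a.T)   # shrink: (num_tokens, rank)
        #   y += buffer @ lora_b.T            # expand: (num_tokens, out_features)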

    def add_lora_logits(self,
                        y: torch.Tensor,
                        x: torch.Tensor,
                        lora_a_stacked: torch.Tensor,
                        lora_b_stacked: torch.Tensor,
                        scale,
                        *,
                        buffer: Optional[torch.Tensor] = None,
                        **kwargs) -> None:
    def add_lora_logits(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        lora_a_stacked: torch.Tensor,
        lora_b_stacked: torch.Tensor,
        scale,
        *,
        buffer: torch.Tensor | None = None,
        **kwargs,
    ) -> None:
        """
        Applies lora specifically for LogitsProcessorWithLoRA.

@@ -350,9 +347,7 @@ class PunicaWrapperNPU(PunicaWrapperBase):
        r = lora_b_stacked.size(-1)

        if buffer is None:
            buffer = torch.zeros((x.size(0), r),
                                 dtype=torch.float32,
                                 device=x.device)
            buffer = torch.zeros((x.size(0), r), dtype=torch.float32, device=x.device)

        indices = self.sampler_indices


@@ -1,91 +1,75 @@
from typing import Optional

import vllm
from torch import nn
from transformers import PretrainedConfig
from vllm.config import LoRAConfig
from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
                              MergedColumnParallelLinearWithLoRA,
                              MergedQKVParallelLinearWithLoRA,
                              QKVParallelLinearWithLoRA,
                              RowParallelLinearWithLoRA,
                              VocabParallelEmbeddingWithLoRA)
from vllm.lora.layers import (
    ColumnParallelLinearWithLoRA,
    MergedColumnParallelLinearWithLoRA,
    MergedQKVParallelLinearWithLoRA,
    QKVParallelLinearWithLoRA,
    RowParallelLinearWithLoRA,
    VocabParallelEmbeddingWithLoRA,
)
from vllm.lora.layers.utils import _not_fully_sharded_can_replace

from vllm_ascend.ops.linear import (AscendColumnParallelLinear,
                                    AscendMergedColumnParallelLinear,
                                    AscendQKVParallelLinear,
                                    AscendRowParallelLinear)
from vllm_ascend.ops.vocab_parallel_embedding import \
    AscendVocabParallelEmbedding
from vllm_ascend.ops.linear import (
    AscendColumnParallelLinear,
    AscendMergedColumnParallelLinear,
    AscendQKVParallelLinear,
    AscendRowParallelLinear,
)
from vllm_ascend.ops.vocab_parallel_embedding import AscendVocabParallelEmbedding


class AscendColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: Optional[PretrainedConfig],
        model_config: PretrainedConfig | None,
    ) -> bool:
        return type(source_layer) is AscendColumnParallelLinear


class AscendMergedColumnParallelLinearWithLoRA(
        MergedColumnParallelLinearWithLoRA):

class AscendMergedColumnParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA):
    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: Optional[PretrainedConfig],
        model_config: PretrainedConfig | None,
    ) -> bool:
        return type(source_layer) is AscendMergedColumnParallelLinear


class AscendRowParallelLinearWithLoRA(RowParallelLinearWithLoRA):

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: Optional[PretrainedConfig],
        model_config: PretrainedConfig | None,
    ) -> bool:
        return type(source_layer) is AscendRowParallelLinear


class AscendVocabParallelEmbeddingWithLoRA(VocabParallelEmbeddingWithLoRA):

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: Optional[PretrainedConfig],
        model_config: PretrainedConfig | None,
    ) -> bool:
        return type(source_layer) is AscendVocabParallelEmbedding


class AscendQKVParallelLinearWithLoRA(QKVParallelLinearWithLoRA):

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: list,
                          model_config: Optional[PretrainedConfig]) -> bool:
        return type(source_layer) is AscendQKVParallelLinear and len(
            packed_modules_list) == 1


class AscendMergedQKVParallelLinearWithLoRA(MergedQKVParallelLinearWithLoRA):

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
@@ -93,18 +77,28 @@ class AscendMergedQKVParallelLinearWithLoRA(MergedQKVParallelLinearWithLoRA):
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: Optional[PretrainedConfig],
        model_config: PretrainedConfig | None,
    ) -> bool:
        return (type(source_layer) is AscendQKVParallelLinear
                and len(packed_modules_list) == 3)
        return type(source_layer) is AscendQKVParallelLinear and len(packed_modules_list) == 1


class AscendMergedQKVParallelLinearWithLoRA(MergedQKVParallelLinearWithLoRA):
    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None,
    ) -> bool:
        return type(source_layer) is AscendQKVParallelLinear and len(packed_modules_list) == 3


def refresh_all_lora_classes():
    vllm.lora.utils._all_lora_classes.add(AscendColumnParallelLinearWithLoRA)
    vllm.lora.utils._all_lora_classes.add(
        AscendMergedColumnParallelLinearWithLoRA)
    vllm.lora.utils._all_lora_classes.add(AscendMergedColumnParallelLinearWithLoRA)
    vllm.lora.utils._all_lora_classes.add(AscendRowParallelLinearWithLoRA)
    vllm.lora.utils._all_lora_classes.add(AscendVocabParallelEmbeddingWithLoRA)
    vllm.lora.utils._all_lora_classes.add(AscendQKVParallelLinearWithLoRA)
    vllm.lora.utils._all_lora_classes.add(
        AscendMergedQKVParallelLinearWithLoRA)
    vllm.lora.utils._all_lora_classes.add(AscendMergedQKVParallelLinearWithLoRA)