### What this PR does / why we need it?
**Scope of Changes**:
| File Path |
| :--- |
| `.../distributed/kv_transfer/kv_pool/ascend_store/ascend_store_connector.py` |
| `vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/backend/backend.py` |
| `.../distributed/kv_transfer/kv_pool/ascend_store/backend/memcache_backend.py` |
| `.../distributed/kv_transfer/kv_pool/ascend_store/backend/mooncake_backend.py` |
| `vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/config_data.py` |
| `vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/kv_transfer.py` |
| `vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/pool_scheduler.py` |
| `vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/pool_worker.py` |
| `.../distributed/kv_transfer/kv_pool/cpu_offload/cpu_kv_cache_manager.py` |
| `.../distributed/kv_transfer/kv_pool/cpu_offload/cpu_offload_connector.py` |
| `vllm_ascend/distributed/kv_transfer/kv_pool/cpu_offload/metadata.py` |
| `vllm_ascend/distributed/kv_transfer/kv_pool/ucm_connector.py` |
| `vllm_ascend/distributed/kv_transfer/utils/mooncake_transfer_engine.py` |
| `vllm_ascend/distributed/kv_transfer/utils/utils.py` |
| `vllm_ascend/kv_offload/cpu_npu.py` |
| `vllm_ascend/kv_offload/npu.py` |
| `vllm_ascend/lora/lora_ops.py` |
| `vllm_ascend/lora/punica_npu.py` |
| `vllm_ascend/lora/utils.py` |
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.13.0
- vLLM main: 2c24bc6996
---------
Signed-off-by: MrZ20 <2609716663@qq.com>
Signed-off-by: SILONG ZENG <2609716663@qq.com>
This commit is contained in:
2
mypy.ini
2
mypy.ini
@@ -1,6 +1,8 @@
|
|||||||
[mypy]
|
[mypy]
|
||||||
; warn_return_any = True
|
; warn_return_any = True
|
||||||
warn_unused_configs = True
|
warn_unused_configs = True
|
||||||
|
; disable errors about unchecked annotations for now.
|
||||||
|
disable_error_code = annotation-unchecked
|
||||||
|
|
||||||
; Suppress all missing import errors from torch_npu for mypy.
|
; Suppress all missing import errors from torch_npu for mypy.
|
||||||
[mypy-torch_npu.*]
|
[mypy-torch_npu.*]
|
||||||
|
|||||||
@@ -51,11 +51,6 @@ line-length = 120
|
|||||||
# Folder to be modified
|
# Folder to be modified
|
||||||
exclude = [
|
exclude = [
|
||||||
"tests/**",
|
"tests/**",
|
||||||
# (5)
|
|
||||||
"vllm_ascend/distributed/kv_transfer/kv_pool/**",
|
|
||||||
"vllm_ascend/distributed/kv_transfer/utils/**",
|
|
||||||
"vllm_ascend/kv_offload/**",
|
|
||||||
"vllm_ascend/lora/**",
|
|
||||||
# (7)
|
# (7)
|
||||||
"vllm_ascend/quantization/**",
|
"vllm_ascend/quantization/**",
|
||||||
"vllm_ascend/sample/*.py",
|
"vllm_ascend/sample/*.py",
|
||||||
|
|||||||
@@ -1,11 +1,10 @@
|
|||||||
import threading
|
import threading
|
||||||
from typing import Any, Optional
|
from typing import Any
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import zmq
|
import zmq
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole
|
||||||
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
|
|
||||||
from vllm.forward_context import ForwardContext
|
from vllm.forward_context import ForwardContext
|
||||||
from vllm.logger import logger
|
from vllm.logger import logger
|
||||||
from vllm.utils.network_utils import make_zmq_socket
|
from vllm.utils.network_utils import make_zmq_socket
|
||||||
@@ -17,31 +16,27 @@ from vllm.v1.request import Request
|
|||||||
from vllm.v1.serial_utils import MsgpackDecoder
|
from vllm.v1.serial_utils import MsgpackDecoder
|
||||||
|
|
||||||
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.pool_scheduler import (
|
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.pool_scheduler import (
|
||||||
KVPoolScheduler, get_zmq_rpc_path_lookup)
|
KVPoolScheduler,
|
||||||
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.pool_worker import \
|
get_zmq_rpc_path_lookup,
|
||||||
KVPoolWorker
|
)
|
||||||
|
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.pool_worker import KVPoolWorker
|
||||||
|
|
||||||
|
|
||||||
class AscendStoreConnector(KVConnectorBase_V1):
|
class AscendStoreConnector(KVConnectorBase_V1):
|
||||||
|
def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole, kv_cache_config: KVCacheConfig | None = None):
|
||||||
def __init__(self,
|
super().__init__(vllm_config=vllm_config, role=role, kv_cache_config=kv_cache_config)
|
||||||
vllm_config: VllmConfig,
|
|
||||||
role: KVConnectorRole,
|
|
||||||
kv_cache_config: Optional[KVCacheConfig] = None):
|
|
||||||
super().__init__(vllm_config=vllm_config,
|
|
||||||
role=role,
|
|
||||||
kv_cache_config=kv_cache_config)
|
|
||||||
self.kv_role = vllm_config.kv_transfer_config.kv_role
|
self.kv_role = vllm_config.kv_transfer_config.kv_role
|
||||||
|
|
||||||
self.use_layerwise = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
|
self.use_layerwise = vllm_config.kv_transfer_config.kv_connector_extra_config.get("use_layerwise", False)
|
||||||
"use_layerwise", False)
|
|
||||||
self.consumer_is_to_put = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
|
self.consumer_is_to_put = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
|
||||||
"consumer_is_to_put", False)
|
"consumer_is_to_put", False
|
||||||
|
)
|
||||||
|
|
||||||
connector_name = vllm_config.kv_transfer_config.kv_connector
|
connector_name = vllm_config.kv_transfer_config.kv_connector
|
||||||
if connector_name == "MooncakeConnectorStoreV1":
|
if connector_name == "MooncakeConnectorStoreV1":
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"It is recommended to use the AscendStoreConnector, as the MoonCakeStoreConnector will be removed in the future."
|
"It is recommended to use the AscendStoreConnector, "
|
||||||
|
"as the MoonCakeStoreConnector will be removed in the future."
|
||||||
)
|
)
|
||||||
|
|
||||||
self.kv_caches: dict[str, torch.Tensor] = {}
|
self.kv_caches: dict[str, torch.Tensor] = {}
|
||||||
@@ -49,8 +44,7 @@ class AscendStoreConnector(KVConnectorBase_V1):
|
|||||||
self.sended_but_unfinished_reqs: set[str] = set()
|
self.sended_but_unfinished_reqs: set[str] = set()
|
||||||
|
|
||||||
if role == KVConnectorRole.SCHEDULER:
|
if role == KVConnectorRole.SCHEDULER:
|
||||||
self.connector_scheduler = KVPoolScheduler(vllm_config,
|
self.connector_scheduler = KVPoolScheduler(vllm_config, self.use_layerwise)
|
||||||
self.use_layerwise)
|
|
||||||
else:
|
else:
|
||||||
self.connector_worker = KVPoolWorker(
|
self.connector_worker = KVPoolWorker(
|
||||||
vllm_config,
|
vllm_config,
|
||||||
@@ -59,27 +53,19 @@ class AscendStoreConnector(KVConnectorBase_V1):
|
|||||||
|
|
||||||
assert self.connector_worker is not None
|
assert self.connector_worker is not None
|
||||||
if vllm_config.parallel_config.rank == 0:
|
if vllm_config.parallel_config.rank == 0:
|
||||||
self.lookup_server = LookupKeyServer(self.connector_worker,
|
self.lookup_server = LookupKeyServer(self.connector_worker, vllm_config, self.use_layerwise)
|
||||||
vllm_config,
|
|
||||||
self.use_layerwise)
|
|
||||||
|
|
||||||
############################################################
|
############################################################
|
||||||
# Scheduler Side Methods
|
# Scheduler Side Methods
|
||||||
############################################################
|
############################################################
|
||||||
|
|
||||||
def get_num_new_matched_tokens(
|
def get_num_new_matched_tokens(self, request: "Request", num_computed_tokens: int) -> tuple[int, bool]:
|
||||||
self, request: "Request",
|
|
||||||
num_computed_tokens: int) -> tuple[int, bool]:
|
|
||||||
assert self.connector_scheduler is not None
|
assert self.connector_scheduler is not None
|
||||||
return self.connector_scheduler.get_num_new_matched_tokens(
|
return self.connector_scheduler.get_num_new_matched_tokens(request, num_computed_tokens)
|
||||||
request, num_computed_tokens)
|
|
||||||
|
|
||||||
def update_state_after_alloc(self, request: "Request",
|
def update_state_after_alloc(self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int):
|
||||||
blocks: "KVCacheBlocks",
|
|
||||||
num_external_tokens: int):
|
|
||||||
assert self.connector_scheduler is not None
|
assert self.connector_scheduler is not None
|
||||||
return self.connector_scheduler.update_state_after_alloc(
|
return self.connector_scheduler.update_state_after_alloc(request, blocks, num_external_tokens)
|
||||||
request, blocks, num_external_tokens)
|
|
||||||
|
|
||||||
def build_connector_meta(
|
def build_connector_meta(
|
||||||
self,
|
self,
|
||||||
@@ -92,7 +78,7 @@ class AscendStoreConnector(KVConnectorBase_V1):
|
|||||||
self,
|
self,
|
||||||
request: "Request",
|
request: "Request",
|
||||||
block_ids: list[int],
|
block_ids: list[int],
|
||||||
) -> tuple[bool, Optional[dict[str, Any]]]:
|
) -> tuple[bool, dict[str, Any] | None]:
|
||||||
assert self.connector_scheduler is not None
|
assert self.connector_scheduler is not None
|
||||||
return self.connector_scheduler.request_finished(request, block_ids)
|
return self.connector_scheduler.request_finished(request, block_ids)
|
||||||
|
|
||||||
@@ -103,8 +89,7 @@ class AscendStoreConnector(KVConnectorBase_V1):
|
|||||||
assert self.connector_worker is not None
|
assert self.connector_worker is not None
|
||||||
self.connector_worker.register_kv_caches(kv_caches)
|
self.connector_worker.register_kv_caches(kv_caches)
|
||||||
|
|
||||||
def start_load_kv(self, forward_context: "ForwardContext",
|
def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
|
||||||
**kwargs) -> None:
|
|
||||||
assert self.connector_worker is not None
|
assert self.connector_worker is not None
|
||||||
self.connector_worker.start_load_kv(self._get_connector_metadata())
|
self.connector_worker.start_load_kv(self._get_connector_metadata())
|
||||||
|
|
||||||
@@ -113,8 +98,9 @@ class AscendStoreConnector(KVConnectorBase_V1):
|
|||||||
return
|
return
|
||||||
self.connector_worker.wait_for_layer_load()
|
self.connector_worker.wait_for_layer_load()
|
||||||
|
|
||||||
def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
|
def save_kv_layer(
|
||||||
attn_metadata: "AttentionMetadata", **kwargs) -> None:
|
self, layer_name: str, kv_layer: torch.Tensor, attn_metadata: "AttentionMetadata", **kwargs
|
||||||
|
) -> None:
|
||||||
if not self.use_layerwise:
|
if not self.use_layerwise:
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -133,17 +119,16 @@ class AscendStoreConnector(KVConnectorBase_V1):
|
|||||||
|
|
||||||
self.connector_worker.wait_for_save(self._get_connector_metadata())
|
self.connector_worker.wait_for_save(self._get_connector_metadata())
|
||||||
|
|
||||||
def get_finished(self,
|
def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
|
||||||
finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
|
|
||||||
"""Get the finished recving and sending requests."""
|
"""Get the finished recving and sending requests."""
|
||||||
assert self.connector_worker is not None
|
assert self.connector_worker is not None
|
||||||
done_sending, done_recving = self.connector_worker.get_finished(
|
done_sending, done_recving = self.connector_worker.get_finished(
|
||||||
finished_req_ids, self._get_connector_metadata())
|
finished_req_ids, self._get_connector_metadata()
|
||||||
|
)
|
||||||
return done_sending, done_recving
|
return done_sending, done_recving
|
||||||
|
|
||||||
|
|
||||||
class LookupKeyServer:
|
class LookupKeyServer:
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
pool_worker: KVPoolWorker,
|
pool_worker: KVPoolWorker,
|
||||||
@@ -171,8 +156,7 @@ class LookupKeyServer:
|
|||||||
token_len = int.from_bytes(all_frames[0], byteorder="big")
|
token_len = int.from_bytes(all_frames[0], byteorder="big")
|
||||||
hash_frames = all_frames[1:]
|
hash_frames = all_frames[1:]
|
||||||
hashes_str = self.decoder.decode(hash_frames)
|
hashes_str = self.decoder.decode(hash_frames)
|
||||||
result = self.pool_worker.lookup_scheduler(
|
result = self.pool_worker.lookup_scheduler(token_len, hashes_str, self.use_layerwise)
|
||||||
token_len, hashes_str, self.use_layerwise)
|
|
||||||
response = result.to_bytes(4, "big")
|
response = result.to_bytes(4, "big")
|
||||||
self.socket.send(response)
|
self.socket.send(response)
|
||||||
|
|
||||||
|
|||||||
@@ -4,13 +4,15 @@ from vllm.config import ParallelConfig
|
|||||||
|
|
||||||
|
|
||||||
class Backend(ABC):
|
class Backend(ABC):
|
||||||
|
@abstractmethod
|
||||||
def __init__(self, parallel_config: ParallelConfig):
|
def __init__(self, parallel_config: ParallelConfig):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
def set_device(self):
|
def set_device(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
def register_buffer(self, ptrs: list[int], lengths: list[int]):
|
def register_buffer(self, ptrs: list[int], lengths: list[int]):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -19,11 +21,9 @@ class Backend(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def put(self, keys: list[str], addrs: list[list[int]],
|
def put(self, keys: list[str], addrs: list[list[int]], sizes: list[list[int]]):
|
||||||
sizes: list[list[int]]):
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get(self, keys: list[str], addrs: list[list[int]],
|
def get(self, keys: list[str], addrs: list[list[int]], sizes: list[list[int]]):
|
||||||
sizes: list[list[int]]):
|
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -5,8 +5,7 @@ import torch
|
|||||||
from vllm.config import ParallelConfig
|
from vllm.config import ParallelConfig
|
||||||
from vllm.logger import logger
|
from vllm.logger import logger
|
||||||
|
|
||||||
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.backend import \
|
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.backend import Backend
|
||||||
Backend
|
|
||||||
from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type
|
from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type
|
||||||
|
|
||||||
|
|
||||||
@@ -18,7 +17,6 @@ class MmcDirect(Enum):
|
|||||||
|
|
||||||
|
|
||||||
class MemcacheBackend(Backend):
|
class MemcacheBackend(Backend):
|
||||||
|
|
||||||
def __init__(self, parallel_config: ParallelConfig):
|
def __init__(self, parallel_config: ParallelConfig):
|
||||||
try:
|
try:
|
||||||
from memcache_hybrid import DistributedObjectStore # type: ignore
|
from memcache_hybrid import DistributedObjectStore # type: ignore
|
||||||
@@ -26,21 +24,17 @@ class MemcacheBackend(Backend):
|
|||||||
raise ImportError(
|
raise ImportError(
|
||||||
"Please install memcache by following the instructions at "
|
"Please install memcache by following the instructions at "
|
||||||
"https://gitee.com/ascend/memfabric_hybrid " # noqa: E501
|
"https://gitee.com/ascend/memfabric_hybrid " # noqa: E501
|
||||||
"to run vLLM with MemcacheConnector.") from e
|
"to run vLLM with MemcacheConnector."
|
||||||
|
) from e
|
||||||
try:
|
try:
|
||||||
soc_version = get_ascend_device_type()
|
soc_version = get_ascend_device_type()
|
||||||
if soc_version in {AscendDeviceType.A2}:
|
if soc_version in {AscendDeviceType.A2}:
|
||||||
import torch
|
import torch
|
||||||
from vllm.distributed import get_world_group
|
from vllm.distributed import get_world_group
|
||||||
|
|
||||||
tmp_tensor = torch.zeros(1, device="npu")
|
tmp_tensor = torch.zeros(1, device="npu")
|
||||||
output_tensor_list = [
|
output_tensor_list = [torch.empty_like(tmp_tensor) for _ in range(torch.distributed.get_world_size())]
|
||||||
torch.empty_like(tmp_tensor)
|
torch.distributed.all_gather(output_tensor_list, tmp_tensor, group=get_world_group().device_group)
|
||||||
for _ in range(torch.distributed.get_world_size())
|
|
||||||
]
|
|
||||||
torch.distributed.all_gather(
|
|
||||||
output_tensor_list,
|
|
||||||
tmp_tensor,
|
|
||||||
group=get_world_group().device_group)
|
|
||||||
self.rank = parallel_config.rank
|
self.rank = parallel_config.rank
|
||||||
self.store = DistributedObjectStore()
|
self.store = DistributedObjectStore()
|
||||||
res = self.store.init(self.rank)
|
res = self.store.init(self.rank)
|
||||||
@@ -54,8 +48,7 @@ class MemcacheBackend(Backend):
|
|||||||
logger.error("Configuration loading failed: %s", e)
|
logger.error("Configuration loading failed: %s", e)
|
||||||
raise
|
raise
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.error(
|
logger.error("An error occurred while loading the configuration: %s", exc)
|
||||||
"An error occurred while loading the configuration: %s", exc)
|
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def set_device(self):
|
def set_device(self):
|
||||||
@@ -73,22 +66,18 @@ class MemcacheBackend(Backend):
|
|||||||
def exists(self, keys: list[str]) -> list[int]:
|
def exists(self, keys: list[str]) -> list[int]:
|
||||||
return self.store.batch_is_exist(keys)
|
return self.store.batch_is_exist(keys)
|
||||||
|
|
||||||
def get(self, key: list[str], addr: list[list[int]],
|
def get(self, key: list[str], addr: list[list[int]], size: list[list[int]]):
|
||||||
size: list[list[int]]):
|
|
||||||
try:
|
try:
|
||||||
res = self.store.batch_get_into_layers(key, addr, size,
|
res = self.store.batch_get_into_layers(key, addr, size, MmcDirect.COPY_G2L.value)
|
||||||
MmcDirect.COPY_G2L.value)
|
|
||||||
for value in res:
|
for value in res:
|
||||||
if value != 0:
|
if value != 0:
|
||||||
logger.error(f"Failed to get key {key},res:{res}")
|
logger.error(f"Failed to get key {key},res:{res}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to get key {key}. {e}")
|
logger.error(f"Failed to get key {key}. {e}")
|
||||||
|
|
||||||
def put(self, key: list[str], addr: list[list[int]],
|
def put(self, key: list[str], addr: list[list[int]], size: list[list[int]]):
|
||||||
size: list[list[int]]):
|
|
||||||
try:
|
try:
|
||||||
res = self.store.batch_put_from_layers(key, addr, size,
|
res = self.store.batch_put_from_layers(key, addr, size, MmcDirect.COPY_L2G.value)
|
||||||
MmcDirect.COPY_L2G.value)
|
|
||||||
for value in res:
|
for value in res:
|
||||||
if value != 0:
|
if value != 0:
|
||||||
logger.error(f"Failed to get key {key},res:{res}")
|
logger.error(f"Failed to get key {key},res:{res}")
|
||||||
|
|||||||
@@ -2,10 +2,9 @@
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import torch
|
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Union
|
|
||||||
|
import torch
|
||||||
|
|
||||||
# Third Party
|
# Third Party
|
||||||
from mooncake.store import ReplicateConfig # type: ignore
|
from mooncake.store import ReplicateConfig # type: ignore
|
||||||
@@ -13,17 +12,14 @@ from vllm.config import ParallelConfig
|
|||||||
from vllm.logger import logger
|
from vllm.logger import logger
|
||||||
from vllm.utils.network_utils import get_ip
|
from vllm.utils.network_utils import get_ip
|
||||||
|
|
||||||
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.backend import \
|
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.backend import Backend
|
||||||
Backend
|
from vllm_ascend.distributed.kv_transfer.utils.mooncake_transfer_engine import global_te
|
||||||
from vllm_ascend.distributed.kv_transfer.utils.mooncake_transfer_engine import \
|
|
||||||
global_te
|
|
||||||
|
|
||||||
DEFAULT_GLOBAL_SEGMENT_SIZE = 3355443200 # 3.125 GiB
|
DEFAULT_GLOBAL_SEGMENT_SIZE = 3355443200 # 3.125 GiB
|
||||||
DEFAULT_LOCAL_BUFFER_SIZE = 1073741824 # 1.0 GiB
|
DEFAULT_LOCAL_BUFFER_SIZE = 1073741824 # 1.0 GiB
|
||||||
|
|
||||||
|
|
||||||
class MooncakeBackend(Backend):
|
class MooncakeBackend(Backend):
|
||||||
|
|
||||||
def __init__(self, parallel_config: ParallelConfig):
|
def __init__(self, parallel_config: ParallelConfig):
|
||||||
try:
|
try:
|
||||||
from mooncake.store import MooncakeDistributedStore # type: ignore
|
from mooncake.store import MooncakeDistributedStore # type: ignore
|
||||||
@@ -31,23 +27,25 @@ class MooncakeBackend(Backend):
|
|||||||
raise ImportError(
|
raise ImportError(
|
||||||
"Please install mooncake by following the instructions at "
|
"Please install mooncake by following the instructions at "
|
||||||
"https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md " # noqa: E501
|
"https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md " # noqa: E501
|
||||||
"to run vLLM with MooncakeConnector.") from e
|
"to run vLLM with MooncakeConnector."
|
||||||
|
) from e
|
||||||
self.config = MooncakeStoreConfig.load_from_env()
|
self.config = MooncakeStoreConfig.load_from_env()
|
||||||
self.store = MooncakeDistributedStore()
|
self.store = MooncakeDistributedStore()
|
||||||
self.rank = parallel_config.rank
|
self.rank = parallel_config.rank
|
||||||
if self.config.protocol == "ascend":
|
if self.config.protocol == "ascend":
|
||||||
local_hostname = get_ip()
|
local_hostname = get_ip()
|
||||||
transfer_engine = global_te.get_transfer_engine(local_hostname,
|
transfer_engine = global_te.get_transfer_engine(local_hostname, device_name=None)
|
||||||
device_name=None)
|
self.local_seg = local_hostname + ":" + str(transfer_engine.get_rpc_port())
|
||||||
self.local_seg = local_hostname + ":" + str(
|
ret = self.store.setup(
|
||||||
transfer_engine.get_rpc_port())
|
self.local_seg,
|
||||||
ret = self.store.setup(self.local_seg, self.config.metadata_server,
|
self.config.metadata_server,
|
||||||
self.config.global_segment_size,
|
self.config.global_segment_size,
|
||||||
self.config.local_buffer_size,
|
self.config.local_buffer_size,
|
||||||
self.config.protocol,
|
self.config.protocol,
|
||||||
self.config.device_name,
|
self.config.device_name,
|
||||||
self.config.master_server_address,
|
self.config.master_server_address,
|
||||||
transfer_engine.get_engine())
|
transfer_engine.get_engine(),
|
||||||
|
)
|
||||||
if ret != 0:
|
if ret != 0:
|
||||||
msg = "Initialize mooncake failed."
|
msg = "Initialize mooncake failed."
|
||||||
logger.error(msg)
|
logger.error(msg)
|
||||||
@@ -63,25 +61,21 @@ class MooncakeBackend(Backend):
|
|||||||
def exists(self, keys: list[str]) -> list[int]:
|
def exists(self, keys: list[str]) -> list[int]:
|
||||||
return self.store.batch_is_exist(keys)
|
return self.store.batch_is_exist(keys)
|
||||||
|
|
||||||
def put(self, keys: list[str], addrs: list[list[int]],
|
def put(self, keys: list[str], addrs: list[list[int]], sizes: list[list[int]]):
|
||||||
sizes: list[list[int]]):
|
|
||||||
try:
|
try:
|
||||||
config = ReplicateConfig()
|
config = ReplicateConfig()
|
||||||
config.preferred_segment = self.local_seg
|
config.preferred_segment = self.local_seg
|
||||||
config.prefer_alloc_in_same_node = True
|
config.prefer_alloc_in_same_node = True
|
||||||
res = self.store.batch_put_from_multi_buffers(
|
res = self.store.batch_put_from_multi_buffers(keys, addrs, sizes, config)
|
||||||
keys, addrs, sizes, config)
|
|
||||||
for value in res:
|
for value in res:
|
||||||
if value < 0:
|
if value < 0:
|
||||||
logger.error(f"Failed to put key {keys},res:{res}")
|
logger.error(f"Failed to put key {keys},res:{res}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to put key {keys},error:{e}")
|
logger.error(f"Failed to put key {keys},error:{e}")
|
||||||
|
|
||||||
def get(self, keys: list[str], addrs: list[list[int]],
|
def get(self, keys: list[str], addrs: list[list[int]], sizes: list[list[int]]):
|
||||||
sizes: list[list[int]]):
|
|
||||||
try:
|
try:
|
||||||
res = self.store.batch_get_into_multi_buffers(
|
res = self.store.batch_get_into_multi_buffers(keys, addrs, sizes, True)
|
||||||
keys, addrs, sizes, True)
|
|
||||||
for value in res:
|
for value in res:
|
||||||
if value < 0:
|
if value < 0:
|
||||||
logger.error(f"Failed to get key {keys}, res:{res}")
|
logger.error(f"Failed to get key {keys}, res:{res}")
|
||||||
@@ -92,7 +86,7 @@ class MooncakeBackend(Backend):
|
|||||||
@dataclass
|
@dataclass
|
||||||
class MooncakeStoreConfig:
|
class MooncakeStoreConfig:
|
||||||
metadata_server: str
|
metadata_server: str
|
||||||
global_segment_size: Union[int, str]
|
global_segment_size: int | str
|
||||||
local_buffer_size: int
|
local_buffer_size: int
|
||||||
protocol: str
|
protocol: str
|
||||||
device_name: str
|
device_name: str
|
||||||
@@ -105,20 +99,19 @@ class MooncakeStoreConfig:
|
|||||||
return MooncakeStoreConfig(
|
return MooncakeStoreConfig(
|
||||||
metadata_server=config.get("metadata_server"),
|
metadata_server=config.get("metadata_server"),
|
||||||
global_segment_size=_parse_global_segment_size(
|
global_segment_size=_parse_global_segment_size(
|
||||||
config.get("global_segment_size",
|
config.get("global_segment_size", DEFAULT_GLOBAL_SEGMENT_SIZE)
|
||||||
DEFAULT_GLOBAL_SEGMENT_SIZE)),
|
),
|
||||||
local_buffer_size=_parse_global_segment_size(
|
local_buffer_size=_parse_global_segment_size(config.get("local_buffer_size", DEFAULT_LOCAL_BUFFER_SIZE)),
|
||||||
config.get("local_buffer_size", DEFAULT_LOCAL_BUFFER_SIZE)),
|
|
||||||
protocol=config.get("protocol", "ascend"),
|
protocol=config.get("protocol", "ascend"),
|
||||||
device_name=config.get("device_name", ""),
|
device_name=config.get("device_name", ""),
|
||||||
master_server_address=config.get("master_server_address"))
|
master_server_address=config.get("master_server_address"),
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def load_from_env() -> "MooncakeStoreConfig":
|
def load_from_env() -> "MooncakeStoreConfig":
|
||||||
config_path = os.getenv("MOONCAKE_CONFIG_PATH")
|
config_path = os.getenv("MOONCAKE_CONFIG_PATH")
|
||||||
if not config_path:
|
if not config_path:
|
||||||
raise ValueError(
|
raise ValueError("The environment variable 'MOONCAKE_CONFIG_PATH' is not set.")
|
||||||
"The environment variable 'MOONCAKE_CONFIG_PATH' is not set.")
|
|
||||||
return MooncakeStoreConfig.from_file(config_path)
|
return MooncakeStoreConfig.from_file(config_path)
|
||||||
|
|
||||||
|
|
||||||
@@ -143,35 +136,32 @@ def _parse_global_segment_size(value) -> int:
|
|||||||
try:
|
try:
|
||||||
return int(value)
|
return int(value)
|
||||||
except (TypeError, ValueError) as e:
|
except (TypeError, ValueError) as e:
|
||||||
raise TypeError(
|
raise TypeError(f"Unsupported type for global_segment_size: {type(value)}") from e
|
||||||
f"Unsupported type for global_segment_size: {type(value)}"
|
|
||||||
) from e
|
|
||||||
|
|
||||||
cleaned_input = value.strip().lower()
|
cleaned_input = value.strip().lower()
|
||||||
if not cleaned_input:
|
if not cleaned_input:
|
||||||
raise ValueError("global segment size cannot be empty.")
|
raise ValueError("global segment size cannot be empty.")
|
||||||
|
|
||||||
UNIT_MULTIPLIERS = {
|
UNIT_MULTIPLIERS = {
|
||||||
'gb': 1024**3, # 1 GB = 1024^3 bytes
|
"gb": 1024**3, # 1 GB = 1024^3 bytes
|
||||||
'mb': 1024**2, # 1 MB = 1024^2 bytes
|
"mb": 1024**2, # 1 MB = 1024^2 bytes
|
||||||
'kb': 1024, # 1 KB = 1024 bytes
|
"kb": 1024, # 1 KB = 1024 bytes
|
||||||
'b': 1 # 1 B = 1 byte
|
"b": 1, # 1 B = 1 byte
|
||||||
}
|
}
|
||||||
pattern = r'^\s*([\d.]+)\s*(gb|mb|kb|b)?\s*$'
|
pattern = r"^\s*([\d.]+)\s*(gb|mb|kb|b)?\s*$"
|
||||||
match = re.match(pattern, cleaned_input)
|
match = re.match(pattern, cleaned_input)
|
||||||
|
|
||||||
if not match:
|
if not match:
|
||||||
raise ValueError(f"Invalid format: '{value}'")
|
raise ValueError(f"Invalid format: '{value}'")
|
||||||
|
|
||||||
number_str = match.group(1)
|
number_str = match.group(1)
|
||||||
unit = match.group(2) or 'b'
|
unit = match.group(2) or "b"
|
||||||
|
|
||||||
multiplier = UNIT_MULTIPLIERS[unit]
|
multiplier = UNIT_MULTIPLIERS[unit]
|
||||||
return _convert_to_bytes(number_str, multiplier, value)
|
return _convert_to_bytes(number_str, multiplier, value)
|
||||||
|
|
||||||
|
|
||||||
def _convert_to_bytes(number_str: str, multiplier: int,
|
def _convert_to_bytes(number_str: str, multiplier: int, original_input: str) -> int:
|
||||||
original_input: str) -> int:
|
|
||||||
"""
|
"""
|
||||||
Convert numeric string to byte count
|
Convert numeric string to byte count
|
||||||
|
|
||||||
@@ -189,8 +179,7 @@ def _convert_to_bytes(number_str: str, multiplier: int,
|
|||||||
try:
|
try:
|
||||||
numeric_value = float(number_str)
|
numeric_value = float(number_str)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
raise ValueError(
|
raise ValueError(f"Invalid numeric value '{number_str}' in: '{original_input}'")
|
||||||
f"Invalid numeric value '{number_str}' in: '{original_input}'")
|
|
||||||
# Calculate byte count
|
# Calculate byte count
|
||||||
try:
|
try:
|
||||||
byte_count = int(numeric_value * multiplier)
|
byte_count = int(numeric_value * multiplier)
|
||||||
|
|||||||
@@ -1,16 +1,16 @@
|
|||||||
|
from collections.abc import Iterable
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Iterable, List, Optional, Tuple, Union
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import \
|
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata
|
||||||
KVConnectorMetadata
|
|
||||||
from vllm.logger import logger
|
from vllm.logger import logger
|
||||||
from vllm.utils.math_utils import cdiv
|
from vllm.utils.math_utils import cdiv
|
||||||
from vllm.v1.core.kv_cache_utils import BlockHash
|
from vllm.v1.core.kv_cache_utils import BlockHash
|
||||||
from vllm.v1.core.sched.output import NewRequestData
|
from vllm.v1.core.sched.output import NewRequestData
|
||||||
|
|
||||||
|
|
||||||
#Parameters related to the key
|
# Parameters related to the key
|
||||||
@dataclass
|
@dataclass
|
||||||
class KeyMetadata:
|
class KeyMetadata:
|
||||||
"""name of the LLM model"""
|
"""name of the LLM model"""
|
||||||
@@ -32,23 +32,26 @@ class PoolKey:
|
|||||||
chunk_hash: str
|
chunk_hash: str
|
||||||
|
|
||||||
def __hash__(self):
|
def __hash__(self):
|
||||||
return hash((
|
return hash(
|
||||||
self.key_metadata.model_name,
|
(
|
||||||
self.key_metadata.head_or_tp_rank,
|
self.key_metadata.model_name,
|
||||||
self.key_metadata.pcp_rank,
|
self.key_metadata.head_or_tp_rank,
|
||||||
self.key_metadata.dcp_rank,
|
self.key_metadata.pcp_rank,
|
||||||
self.key_metadata.pp_rank,
|
self.key_metadata.dcp_rank,
|
||||||
self.chunk_hash,
|
self.key_metadata.pp_rank,
|
||||||
))
|
self.chunk_hash,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
def to_string(self):
|
def to_string(self):
|
||||||
return (
|
return (
|
||||||
f"{self.key_metadata.model_name}"
|
f"{self.key_metadata.model_name}"
|
||||||
f"@pcp{self.key_metadata.pcp_rank}@dcp{self.key_metadata.dcp_rank}"
|
f"@pcp{self.key_metadata.pcp_rank}@dcp{self.key_metadata.dcp_rank}"
|
||||||
f"@head_or_tp_rank:{self.key_metadata.head_or_tp_rank}"
|
f"@head_or_tp_rank:{self.key_metadata.head_or_tp_rank}"
|
||||||
f"@pp_rank:{self.key_metadata.pp_rank}@{self.chunk_hash}")
|
f"@pp_rank:{self.key_metadata.pp_rank}@{self.chunk_hash}"
|
||||||
|
)
|
||||||
|
|
||||||
def split_layers(self, num_layers: int) -> List["LayerPoolKey"]:
|
def split_layers(self, num_layers: int) -> list["LayerPoolKey"]:
|
||||||
"""Split the key into multiple keys for each layer"""
|
"""Split the key into multiple keys for each layer"""
|
||||||
keys = []
|
keys = []
|
||||||
for layer_id in range(num_layers):
|
for layer_id in range(num_layers):
|
||||||
@@ -57,7 +60,8 @@ class PoolKey:
|
|||||||
self.key_metadata,
|
self.key_metadata,
|
||||||
self.chunk_hash,
|
self.chunk_hash,
|
||||||
layer_id,
|
layer_id,
|
||||||
))
|
)
|
||||||
|
)
|
||||||
return keys
|
return keys
|
||||||
|
|
||||||
|
|
||||||
@@ -68,14 +72,16 @@ class LayerPoolKey(PoolKey):
|
|||||||
layer_id: int
|
layer_id: int
|
||||||
|
|
||||||
def __hash__(self):
|
def __hash__(self):
|
||||||
return hash((
|
return hash(
|
||||||
self.key_metadata.model_name,
|
(
|
||||||
self.key_metadata.head_or_tp_rank,
|
self.key_metadata.model_name,
|
||||||
self.key_metadata.pcp_rank,
|
self.key_metadata.head_or_tp_rank,
|
||||||
self.key_metadata.dcp_rank,
|
self.key_metadata.pcp_rank,
|
||||||
self.chunk_hash,
|
self.key_metadata.dcp_rank,
|
||||||
self.layer_id,
|
self.chunk_hash,
|
||||||
))
|
self.layer_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
def to_string(self):
|
def to_string(self):
|
||||||
return (
|
return (
|
||||||
@@ -85,10 +91,8 @@ class LayerPoolKey(PoolKey):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ChunkedTokenDatabase():
|
class ChunkedTokenDatabase:
|
||||||
|
def __init__(self, metadata: KeyMetadata, block_size: int, use_mla: bool, partitions: list[int] | None):
|
||||||
def __init__(self, metadata: KeyMetadata, block_size: int, use_mla: bool,
|
|
||||||
partitions: Optional[List[int]]):
|
|
||||||
self.metadata = metadata
|
self.metadata = metadata
|
||||||
self.block_size = block_size
|
self.block_size = block_size
|
||||||
self.use_mla = use_mla
|
self.use_mla = use_mla
|
||||||
@@ -96,9 +100,7 @@ class ChunkedTokenDatabase():
|
|||||||
self.block_len: list[int] = []
|
self.block_len: list[int] = []
|
||||||
self.partitions = partitions
|
self.partitions = partitions
|
||||||
|
|
||||||
def _make_key_by_hash(self,
|
def _make_key_by_hash(self, chunk_hash: str, layer_id: int | None = None):
|
||||||
chunk_hash: str,
|
|
||||||
layer_id: Optional[int] = None):
|
|
||||||
assert self.metadata is not None
|
assert self.metadata is not None
|
||||||
return PoolKey(
|
return PoolKey(
|
||||||
self.metadata,
|
self.metadata,
|
||||||
@@ -116,8 +118,7 @@ class ChunkedTokenDatabase():
|
|||||||
size_list = []
|
size_list = []
|
||||||
block_id = block_ids[start // self.block_size]
|
block_id = block_ids[start // self.block_size]
|
||||||
for index, base_addr in enumerate(self.kv_caches_base_addr):
|
for index, base_addr in enumerate(self.kv_caches_base_addr):
|
||||||
block_len = (self.block_len[index % 2]
|
block_len = self.block_len[index % 2] if self.use_mla else self.block_len[0]
|
||||||
if self.use_mla else self.block_len[0])
|
|
||||||
|
|
||||||
addr = base_addr + block_id * block_len
|
addr = base_addr + block_id * block_len
|
||||||
length = int(block_len / self.block_size * (end - start))
|
length = int(block_len / self.block_size * (end - start))
|
||||||
@@ -125,22 +126,17 @@ class ChunkedTokenDatabase():
|
|||||||
size_list.append(length)
|
size_list.append(length)
|
||||||
return addr_list, size_list, block_id
|
return addr_list, size_list, block_id
|
||||||
|
|
||||||
def prepare_value_layer(self, start: int, end: int, block_ids: list[int],
|
def prepare_value_layer(self, start: int, end: int, block_ids: list[int], layer_id: int):
|
||||||
layer_id: int):
|
|
||||||
block_id = block_ids[start // self.block_size]
|
block_id = block_ids[start // self.block_size]
|
||||||
if self.use_mla:
|
if self.use_mla:
|
||||||
addr_k = self.kv_caches_base_addr[layer_id *
|
addr_k = self.kv_caches_base_addr[layer_id * 2] + block_id * self.block_len[0]
|
||||||
2] + block_id * self.block_len[0]
|
addr_v = self.kv_caches_base_addr[layer_id * 2 + 1] + block_id * self.block_len[1]
|
||||||
addr_v = self.kv_caches_base_addr[layer_id * 2 +
|
|
||||||
1] + block_id * self.block_len[1]
|
|
||||||
length_k = int(self.block_len[0] / self.block_size * (end - start))
|
length_k = int(self.block_len[0] / self.block_size * (end - start))
|
||||||
length_v = int(self.block_len[1] / self.block_size * (end - start))
|
length_v = int(self.block_len[1] / self.block_size * (end - start))
|
||||||
size_list = [length_k, length_v]
|
size_list = [length_k, length_v]
|
||||||
else:
|
else:
|
||||||
addr_k = self.kv_caches_base_addr[layer_id *
|
addr_k = self.kv_caches_base_addr[layer_id * 2] + block_id * self.block_len[0]
|
||||||
2] + block_id * self.block_len[0]
|
addr_v = self.kv_caches_base_addr[layer_id * 2 + 1] + block_id * self.block_len[0]
|
||||||
addr_v = self.kv_caches_base_addr[layer_id * 2 +
|
|
||||||
1] + block_id * self.block_len[0]
|
|
||||||
length = int(self.block_len[0] / self.block_size * (end - start))
|
length = int(self.block_len[0] / self.block_size * (end - start))
|
||||||
size_list = [length, length]
|
size_list = [length, length]
|
||||||
addr_list = [addr_k, addr_v]
|
addr_list = [addr_k, addr_v]
|
||||||
@@ -149,9 +145,9 @@ class ChunkedTokenDatabase():
|
|||||||
def process_tokens(
|
def process_tokens(
|
||||||
self,
|
self,
|
||||||
token_len: int,
|
token_len: int,
|
||||||
block_hashes: Union[list[BlockHash], list[str]],
|
block_hashes: list[BlockHash] | list[str],
|
||||||
mask_num: int = 0,
|
mask_num: int = 0,
|
||||||
) -> Iterable[Tuple[int, int, PoolKey]]:
|
) -> Iterable[tuple[int, int, PoolKey]]:
|
||||||
"""Process the tokens and return the corresponding cache engine keys.
|
"""Process the tokens and return the corresponding cache engine keys.
|
||||||
|
|
||||||
:param Union[torch.Tensor, List[int]] tokens: The tokens to process.
|
:param Union[torch.Tensor, List[int]] tokens: The tokens to process.
|
||||||
@@ -202,10 +198,10 @@ class ChunkedTokenDatabase():
|
|||||||
start = 0
|
start = 0
|
||||||
for j, part in enumerate(self.partitions):
|
for j, part in enumerate(self.partitions):
|
||||||
# part * 2 because addr and size contain both k and v
|
# part * 2 because addr and size contain both k and v
|
||||||
end = len(addr_list) if j == len(
|
end = len(addr_list) if j == len(self.partitions) - 1 else start + part * 2
|
||||||
self.partitions) - 1 else start + part * 2
|
|
||||||
new_str = key[i].replace( # type: ignore[attr-defined]
|
new_str = key[i].replace( # type: ignore[attr-defined]
|
||||||
"@pp_rank:0", f"@pp_rank:{j}", 1)
|
"@pp_rank:0", f"@pp_rank:{j}", 1
|
||||||
|
)
|
||||||
new_key.append(new_str)
|
new_key.append(new_str)
|
||||||
new_addr.append(addr_list[start:end])
|
new_addr.append(addr_list[start:end])
|
||||||
new_size.append(size_list[start:end])
|
new_size.append(size_list[start:end])
|
||||||
@@ -213,7 +209,7 @@ class ChunkedTokenDatabase():
|
|||||||
return new_key, new_addr, new_size
|
return new_key, new_addr, new_size
|
||||||
|
|
||||||
|
|
||||||
#Parameters related to the connector metadata
|
# Parameters related to the connector metadata
|
||||||
@dataclass
|
@dataclass
|
||||||
class LoadSpec:
|
class LoadSpec:
|
||||||
# Number of tokens cached in vLLM
|
# Number of tokens cached in vLLM
|
||||||
@@ -273,7 +269,7 @@ class RequestTracker:
|
|||||||
|
|
||||||
def update(
|
def update(
|
||||||
self,
|
self,
|
||||||
new_block_ids: Union[tuple[list[int], ...], list[int]],
|
new_block_ids: tuple[list[int], ...] | list[int],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Update the request tracker when a running request is
|
"""Update the request tracker when a running request is
|
||||||
scheduled again
|
scheduled again
|
||||||
@@ -286,8 +282,7 @@ class RequestTracker:
|
|||||||
elif isinstance(new_block_ids, list):
|
elif isinstance(new_block_ids, list):
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(f"Unsupported new_block_ids type {type(new_block_ids)}")
|
||||||
f"Unsupported new_block_ids type {type(new_block_ids)}")
|
|
||||||
self.allocated_block_ids.extend(new_block_ids)
|
self.allocated_block_ids.extend(new_block_ids)
|
||||||
|
|
||||||
|
|
||||||
@@ -302,22 +297,22 @@ class ReqMeta:
|
|||||||
|
|
||||||
block_hashes: list[BlockHash]
|
block_hashes: list[BlockHash]
|
||||||
|
|
||||||
can_save: Optional[bool] = None
|
can_save: bool | None = None
|
||||||
# load_spec
|
# load_spec
|
||||||
load_spec: Optional[LoadSpec] = None
|
load_spec: LoadSpec | None = None
|
||||||
|
|
||||||
is_last_chunk: Optional[bool] = None
|
is_last_chunk: bool | None = None
|
||||||
|
|
||||||
current_event: Optional[torch.npu.Event] = None
|
current_event: torch.npu.Event | None = None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_request_tracker(
|
def from_request_tracker(
|
||||||
tracker: RequestTracker,
|
tracker: RequestTracker,
|
||||||
block_size: int,
|
block_size: int,
|
||||||
load_spec: Optional[LoadSpec] = None,
|
load_spec: LoadSpec | None = None,
|
||||||
skip_save: Optional[bool] = False,
|
skip_save: bool | None = False,
|
||||||
block_hashes: list[BlockHash] = [],
|
block_hashes: list[BlockHash] | None = None,
|
||||||
is_last_chunk: Optional[bool] = None,
|
is_last_chunk: bool | None = None,
|
||||||
discard_partial_chunks: bool = True,
|
discard_partial_chunks: bool = True,
|
||||||
) -> Optional["ReqMeta"]:
|
) -> Optional["ReqMeta"]:
|
||||||
"""Create the request metadata from a request tracker.
|
"""Create the request metadata from a request tracker.
|
||||||
@@ -333,17 +328,17 @@ class ReqMeta:
|
|||||||
the request metadata if we need to perform load/save
|
the request metadata if we need to perform load/save
|
||||||
operations, None otherwise.
|
operations, None otherwise.
|
||||||
"""
|
"""
|
||||||
|
if block_hashes is None:
|
||||||
|
block_hashes = []
|
||||||
input_token_len = tracker.token_len
|
input_token_len = tracker.token_len
|
||||||
|
|
||||||
# For save operation: do not save if the following condition is met
|
# For save operation: do not save if the following condition is met
|
||||||
# 1. has already been saved before (num_saved_tokens > 0)
|
# 1. has already been saved before (num_saved_tokens > 0)
|
||||||
# 2. number of unsaved tokens is not reached the chunk boundary
|
# 2. number of unsaved tokens is not reached the chunk boundary
|
||||||
chunk_boundary = (cdiv(tracker.num_saved_tokens + 1, block_size) *
|
chunk_boundary = cdiv(tracker.num_saved_tokens + 1, block_size) * block_size if discard_partial_chunks else 0
|
||||||
block_size if discard_partial_chunks else 0)
|
|
||||||
# Calculate number of tokens to save based on discard_partial_chunks
|
# Calculate number of tokens to save based on discard_partial_chunks
|
||||||
# setting
|
# setting
|
||||||
num_tokens_to_save = ((input_token_len // block_size * block_size)
|
num_tokens_to_save = (input_token_len // block_size * block_size) if discard_partial_chunks else input_token_len
|
||||||
if discard_partial_chunks else input_token_len)
|
|
||||||
|
|
||||||
skip_save = skip_save or num_tokens_to_save < chunk_boundary
|
skip_save = skip_save or num_tokens_to_save < chunk_boundary
|
||||||
if skip_save and load_spec is None:
|
if skip_save and load_spec is None:
|
||||||
@@ -363,9 +358,7 @@ class ReqMeta:
|
|||||||
else:
|
else:
|
||||||
# Do not load if not in `can_load` state
|
# Do not load if not in `can_load` state
|
||||||
load_spec = None
|
load_spec = None
|
||||||
logger.debug(
|
logger.debug(f"request:{tracker.req_id}, meta save spec:{not skip_save}, meta load spec:{load_spec}")
|
||||||
f"request:{tracker.req_id}, meta save spec:{not skip_save}, meta load spec:{load_spec}"
|
|
||||||
)
|
|
||||||
return ReqMeta(
|
return ReqMeta(
|
||||||
req_id=tracker.req_id,
|
req_id=tracker.req_id,
|
||||||
token_len_chunk=num_tokens_to_save,
|
token_len_chunk=num_tokens_to_save,
|
||||||
@@ -378,7 +371,6 @@ class ReqMeta:
|
|||||||
|
|
||||||
|
|
||||||
class AscendConnectorMetadata(KVConnectorMetadata):
|
class AscendConnectorMetadata(KVConnectorMetadata):
|
||||||
|
|
||||||
def __init__(self, unfinished_request_ids, preempted_req_ids):
|
def __init__(self, unfinished_request_ids, preempted_req_ids):
|
||||||
self.requests = []
|
self.requests = []
|
||||||
self.unfinished_request_ids = unfinished_request_ids
|
self.unfinished_request_ids = unfinished_request_ids
|
||||||
@@ -396,10 +388,10 @@ class AscendConnectorMetadata(KVConnectorMetadata):
|
|||||||
@dataclass
|
@dataclass
|
||||||
class LasyerMultiBlockReqMeta:
|
class LasyerMultiBlockReqMeta:
|
||||||
req_id: str
|
req_id: str
|
||||||
keys: List[LayerPoolKey]
|
keys: list[LayerPoolKey]
|
||||||
starts: List[int]
|
starts: list[int]
|
||||||
ends: list[int]
|
ends: list[int]
|
||||||
block_ids: list[int]
|
block_ids: list[int]
|
||||||
layer_id: int
|
layer_id: int
|
||||||
is_last_chunk: Optional[bool] = True
|
is_last_chunk: bool | None = True
|
||||||
current_event: Optional[torch.npu.Event] = None
|
current_event: torch.npu.Event | None = None
|
||||||
|
|||||||
@@ -7,8 +7,7 @@ from typing import Any
|
|||||||
import torch
|
import torch
|
||||||
from vllm.logger import logger
|
from vllm.logger import logger
|
||||||
|
|
||||||
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.backend import \
|
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.backend import Backend
|
||||||
Backend
|
|
||||||
|
|
||||||
# isort: off
|
# isort: off
|
||||||
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.config_data import (
|
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.config_data import (
|
||||||
@@ -20,10 +19,16 @@ from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.config_data import
|
|||||||
|
|
||||||
|
|
||||||
class KVTransferThread(threading.Thread):
|
class KVTransferThread(threading.Thread):
|
||||||
|
def __init__(
|
||||||
def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
|
self,
|
||||||
block_size: int, tp_rank: int, dcp_size: int,
|
m_store: Backend,
|
||||||
ready_event: threading.Event, name: str):
|
token_database: ChunkedTokenDatabase,
|
||||||
|
block_size: int,
|
||||||
|
tp_rank: int,
|
||||||
|
dcp_size: int,
|
||||||
|
ready_event: threading.Event,
|
||||||
|
name: str,
|
||||||
|
):
|
||||||
super().__init__(daemon=True, name=name)
|
super().__init__(daemon=True, name=name)
|
||||||
self.m_store = m_store
|
self.m_store = m_store
|
||||||
self.ready_event = ready_event
|
self.ready_event = ready_event
|
||||||
@@ -39,7 +44,7 @@ class KVTransferThread(threading.Thread):
|
|||||||
|
|
||||||
def add_request(
|
def add_request(
|
||||||
self,
|
self,
|
||||||
request: ReqMeta,
|
request: ReqMeta | LasyerMultiBlockReqMeta,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
self.request_queue.put(request)
|
self.request_queue.put(request)
|
||||||
|
|
||||||
@@ -98,17 +103,20 @@ class KVTransferThread(threading.Thread):
|
|||||||
|
|
||||||
|
|
||||||
class KVCacheStoreSendingThread(KVTransferThread):
|
class KVCacheStoreSendingThread(KVTransferThread):
|
||||||
|
def __init__(
|
||||||
def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
|
self,
|
||||||
block_size: int, tp_rank: int, dcp_size: int, put_step: int,
|
m_store: Backend,
|
||||||
kv_role: str, ready_event: threading.Event):
|
token_database: ChunkedTokenDatabase,
|
||||||
super().__init__(m_store,
|
block_size: int,
|
||||||
token_database,
|
tp_rank: int,
|
||||||
block_size,
|
dcp_size: int,
|
||||||
tp_rank,
|
put_step: int,
|
||||||
dcp_size,
|
kv_role: str,
|
||||||
ready_event,
|
ready_event: threading.Event,
|
||||||
name="KVCacheSendingThread")
|
):
|
||||||
|
super().__init__(
|
||||||
|
m_store, token_database, block_size, tp_rank, dcp_size, ready_event, name="KVCacheSendingThread"
|
||||||
|
)
|
||||||
self.put_step = put_step
|
self.put_step = put_step
|
||||||
self.kv_role = kv_role
|
self.kv_role = kv_role
|
||||||
self.stored_requests = defaultdict[str, int](int)
|
self.stored_requests = defaultdict[str, int](int)
|
||||||
@@ -139,16 +147,15 @@ class KVCacheStoreSendingThread(KVTransferThread):
|
|||||||
self.request_queue.task_done()
|
self.request_queue.task_done()
|
||||||
return
|
return
|
||||||
|
|
||||||
for start, end, key in self.token_database.process_tokens(
|
for start, end, key in self.token_database.process_tokens(token_len, req_meta.block_hashes):
|
||||||
token_len, req_meta.block_hashes):
|
|
||||||
starts.append(start)
|
starts.append(start)
|
||||||
ends.append(end)
|
ends.append(end)
|
||||||
keys.append(key.to_string())
|
keys.append(key.to_string())
|
||||||
|
|
||||||
if not self.dcp_size > 1:
|
if not self.dcp_size > 1:
|
||||||
starts = starts[self.tp_rank % self.put_step::self.put_step]
|
starts = starts[self.tp_rank % self.put_step :: self.put_step]
|
||||||
ends = ends[self.tp_rank % self.put_step::self.put_step]
|
ends = ends[self.tp_rank % self.put_step :: self.put_step]
|
||||||
keys = keys[self.tp_rank % self.put_step::self.put_step]
|
keys = keys[self.tp_rank % self.put_step :: self.put_step]
|
||||||
|
|
||||||
if not keys:
|
if not keys:
|
||||||
self.dec_stored_request(req_id)
|
self.dec_stored_request(req_id)
|
||||||
@@ -165,8 +172,7 @@ class KVCacheStoreSendingThread(KVTransferThread):
|
|||||||
keys = keys[skip_block_num:]
|
keys = keys[skip_block_num:]
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"Storing KV cache for %d out of %d blocks "
|
"Storing KV cache for %d out of %d blocks (skip_block_num=%d) for request %s",
|
||||||
"(skip_block_num=%d) for request %s",
|
|
||||||
len(keys),
|
len(keys),
|
||||||
token_len // self.block_size,
|
token_len // self.block_size,
|
||||||
skip_block_num,
|
skip_block_num,
|
||||||
@@ -183,14 +189,12 @@ class KVCacheStoreSendingThread(KVTransferThread):
|
|||||||
addrs = []
|
addrs = []
|
||||||
sizes = []
|
sizes = []
|
||||||
for index, start in enumerate(starts):
|
for index, start in enumerate(starts):
|
||||||
addr, size, _ = self.token_database.prepare_value(
|
addr, size, _ = self.token_database.prepare_value(start, ends[index], block_ids)
|
||||||
start, ends[index], block_ids)
|
|
||||||
addrs.append(addr)
|
addrs.append(addr)
|
||||||
sizes.append(size)
|
sizes.append(size)
|
||||||
|
|
||||||
if self.kv_role == "kv_consumer":
|
if self.kv_role == "kv_consumer":
|
||||||
keys, addrs, sizes = self.token_database.decode_adaptor_prefill_pp(
|
keys, addrs, sizes = self.token_database.decode_adaptor_prefill_pp(keys, addrs, sizes)
|
||||||
keys, addrs, sizes)
|
|
||||||
|
|
||||||
if current_event is not None:
|
if current_event is not None:
|
||||||
current_event.synchronize()
|
current_event.synchronize()
|
||||||
@@ -201,69 +205,69 @@ class KVCacheStoreSendingThread(KVTransferThread):
|
|||||||
|
|
||||||
|
|
||||||
class KVCacheStoreRecvingThread(KVTransferThread):
|
class KVCacheStoreRecvingThread(KVTransferThread):
|
||||||
|
def __init__(
|
||||||
def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
|
self,
|
||||||
block_size: int, tp_rank: int, dcp_size: int,
|
m_store: Backend,
|
||||||
ready_event: threading.Event):
|
token_database: ChunkedTokenDatabase,
|
||||||
super().__init__(m_store,
|
block_size: int,
|
||||||
token_database,
|
tp_rank: int,
|
||||||
block_size,
|
dcp_size: int,
|
||||||
tp_rank,
|
ready_event: threading.Event,
|
||||||
dcp_size,
|
):
|
||||||
ready_event,
|
super().__init__(
|
||||||
name="KVCacheStoreRecvingThread")
|
m_store, token_database, block_size, tp_rank, dcp_size, ready_event, name="KVCacheStoreRecvingThread"
|
||||||
|
)
|
||||||
|
|
||||||
def _handle_request(self, req_meta: ReqMeta):
|
def _handle_request(self, req_meta: ReqMeta):
|
||||||
token_len = req_meta.load_spec.token_len # type: ignore[union-attr]
|
token_len = req_meta.load_spec.token_len # type: ignore[union-attr]
|
||||||
req_id = req_meta.req_id
|
req_id = req_meta.req_id
|
||||||
mask_num = (
|
mask_num = (
|
||||||
req_meta.load_spec.vllm_cached_tokens # type: ignore[union-attr]
|
req_meta.load_spec.vllm_cached_tokens # type: ignore[union-attr]
|
||||||
// self.block_size * self.block_size)
|
// self.block_size
|
||||||
|
* self.block_size
|
||||||
|
)
|
||||||
addr_list = []
|
addr_list = []
|
||||||
size_list = []
|
size_list = []
|
||||||
key_list = []
|
key_list = []
|
||||||
for start, end, key in self.token_database.process_tokens(
|
for start, end, key in self.token_database.process_tokens(token_len, req_meta.block_hashes, mask_num):
|
||||||
token_len, req_meta.block_hashes, mask_num):
|
addr, size, _ = self.token_database.prepare_value(start, end, req_meta.block_ids)
|
||||||
addr, size, _ = self.token_database.prepare_value(
|
|
||||||
start, end, req_meta.block_ids)
|
|
||||||
key_list.append(key.to_string())
|
key_list.append(key.to_string())
|
||||||
addr_list.append(addr)
|
addr_list.append(addr)
|
||||||
size_list.append(size)
|
size_list.append(size)
|
||||||
key_list_c = key_list[self.tp_rank %
|
key_list_c = key_list[self.tp_rank % len(key_list) :] + key_list[: self.tp_rank % len(key_list)]
|
||||||
len(key_list):] + key_list[:self.tp_rank %
|
addr_list_c = addr_list[self.tp_rank % len(addr_list) :] + addr_list[: self.tp_rank % len(addr_list)]
|
||||||
len(key_list)]
|
size_list_c = size_list[self.tp_rank % len(size_list) :] + size_list[: self.tp_rank % len(size_list)]
|
||||||
addr_list_c = addr_list[self.tp_rank %
|
|
||||||
len(addr_list):] + addr_list[:self.tp_rank %
|
|
||||||
len(addr_list)]
|
|
||||||
size_list_c = size_list[self.tp_rank %
|
|
||||||
len(size_list):] + size_list[:self.tp_rank %
|
|
||||||
len(size_list)]
|
|
||||||
self.m_store.get(key_list_c, addr_list_c, size_list_c)
|
self.m_store.get(key_list_c, addr_list_c, size_list_c)
|
||||||
self.set_finished_request(req_id)
|
self.set_finished_request(req_id)
|
||||||
self.request_queue.task_done()
|
self.request_queue.task_done()
|
||||||
|
|
||||||
|
|
||||||
class KVCacheStoreLayerSendingThread(KVTransferThread):
|
class KVCacheStoreLayerSendingThread(KVTransferThread):
|
||||||
|
def __init__(
|
||||||
def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
|
self,
|
||||||
block_size: int, tp_rank: int, dcp_size: int, put_step: int,
|
m_store: Backend,
|
||||||
ready_event: threading.Event, num_layers: int):
|
token_database: ChunkedTokenDatabase,
|
||||||
super().__init__(m_store,
|
block_size: int,
|
||||||
token_database,
|
tp_rank: int,
|
||||||
block_size,
|
dcp_size: int,
|
||||||
tp_rank,
|
put_step: int,
|
||||||
dcp_size,
|
ready_event: threading.Event,
|
||||||
ready_event,
|
num_layers: int,
|
||||||
name="KVCacheStoreLayerSendingThread")
|
):
|
||||||
|
super().__init__(
|
||||||
|
m_store, token_database, block_size, tp_rank, dcp_size, ready_event, name="KVCacheStoreLayerSendingThread"
|
||||||
|
)
|
||||||
self.final_layer_id = num_layers - 1
|
self.final_layer_id = num_layers - 1
|
||||||
self.put_step = put_step
|
self.put_step = put_step
|
||||||
|
|
||||||
def add_request( # type: ignore[override]
|
def add_request( # type: ignore[override]
|
||||||
self, req_meta: ReqMeta) -> torch.Tensor:
|
self, req_meta: ReqMeta
|
||||||
|
) -> torch.Tensor:
|
||||||
self.request_queue.put(req_meta)
|
self.request_queue.put(req_meta)
|
||||||
|
|
||||||
def _handle_request( # type: ignore[override]
|
def _handle_request( # type: ignore[override]
|
||||||
self, req_meta: LasyerMultiBlockReqMeta):
|
self, req_meta: LasyerMultiBlockReqMeta
|
||||||
|
):
|
||||||
starts = req_meta.starts
|
starts = req_meta.starts
|
||||||
ends = req_meta.ends
|
ends = req_meta.ends
|
||||||
keys = req_meta.keys
|
keys = req_meta.keys
|
||||||
@@ -272,9 +276,9 @@ class KVCacheStoreLayerSendingThread(KVTransferThread):
|
|||||||
total_block = len(keys)
|
total_block = len(keys)
|
||||||
is_last_chunk = req_meta.is_last_chunk
|
is_last_chunk = req_meta.is_last_chunk
|
||||||
if not self.dcp_size > 1:
|
if not self.dcp_size > 1:
|
||||||
starts = starts[self.tp_rank % self.put_step::self.put_step]
|
starts = starts[self.tp_rank % self.put_step :: self.put_step]
|
||||||
ends = ends[self.tp_rank % self.put_step::self.put_step]
|
ends = ends[self.tp_rank % self.put_step :: self.put_step]
|
||||||
keys = keys[self.tp_rank % self.put_step::self.put_step]
|
keys = keys[self.tp_rank % self.put_step :: self.put_step]
|
||||||
|
|
||||||
if not keys:
|
if not keys:
|
||||||
if is_last_chunk:
|
if is_last_chunk:
|
||||||
@@ -300,7 +304,8 @@ class KVCacheStoreLayerSendingThread(KVTransferThread):
|
|||||||
size_list = []
|
size_list = []
|
||||||
for index, key in enumerate(key_list):
|
for index, key in enumerate(key_list):
|
||||||
addr, size = self.token_database.prepare_value_layer(
|
addr, size = self.token_database.prepare_value_layer(
|
||||||
starts[index], ends[index], req_meta.block_ids, layer_id)
|
starts[index], ends[index], req_meta.block_ids, layer_id
|
||||||
|
)
|
||||||
addr_list.append(addr)
|
addr_list.append(addr)
|
||||||
size_list.append(size)
|
size_list.append(size)
|
||||||
|
|
||||||
@@ -313,8 +318,7 @@ class KVCacheStoreLayerSendingThread(KVTransferThread):
|
|||||||
self.request_queue.task_done()
|
self.request_queue.task_done()
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"Storing KV cache for %d out of %d blocks "
|
"Storing KV cache for %d out of %d blocks (skip_block_num=%d) for request %s",
|
||||||
"(skip_block_num=%d) for request %s",
|
|
||||||
len(keys),
|
len(keys),
|
||||||
total_block,
|
total_block,
|
||||||
skip_block_num,
|
skip_block_num,
|
||||||
@@ -323,44 +327,42 @@ class KVCacheStoreLayerSendingThread(KVTransferThread):
|
|||||||
|
|
||||||
|
|
||||||
class KVCacheStoreLayerRecvingThread(KVTransferThread):
|
class KVCacheStoreLayerRecvingThread(KVTransferThread):
|
||||||
|
def __init__(
|
||||||
def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
|
self,
|
||||||
block_size: int, tp_rank: int, dcp_size: int,
|
m_store: Backend,
|
||||||
ready_event: threading.Event, get_event: threading.Event):
|
token_database: ChunkedTokenDatabase,
|
||||||
super().__init__(m_store,
|
block_size: int,
|
||||||
token_database,
|
tp_rank: int,
|
||||||
block_size,
|
dcp_size: int,
|
||||||
tp_rank,
|
ready_event: threading.Event,
|
||||||
dcp_size,
|
get_event: threading.Event,
|
||||||
ready_event,
|
):
|
||||||
name="KVCacheStoreLayerRecvingThread")
|
super().__init__(
|
||||||
|
m_store, token_database, block_size, tp_rank, dcp_size, ready_event, name="KVCacheStoreLayerRecvingThread"
|
||||||
|
)
|
||||||
self.get_event = get_event
|
self.get_event = get_event
|
||||||
|
|
||||||
def add_request( # type: ignore[override]
|
def add_request( # type: ignore[override]
|
||||||
self, req_meta: LasyerMultiBlockReqMeta) -> torch.Tensor:
|
self, req_meta: LasyerMultiBlockReqMeta
|
||||||
|
) -> torch.Tensor:
|
||||||
self.request_queue.put(req_meta)
|
self.request_queue.put(req_meta)
|
||||||
|
|
||||||
def _handle_request( # type: ignore[override]
|
def _handle_request( # type: ignore[override]
|
||||||
self, req_meta: LasyerMultiBlockReqMeta):
|
self, req_meta: LasyerMultiBlockReqMeta
|
||||||
|
):
|
||||||
addr_list = []
|
addr_list = []
|
||||||
size_list = []
|
size_list = []
|
||||||
key_list = []
|
key_list = []
|
||||||
for index, key in enumerate(req_meta.keys):
|
for index, key in enumerate(req_meta.keys):
|
||||||
addr, size = self.token_database.prepare_value_layer(
|
addr, size = self.token_database.prepare_value_layer(
|
||||||
req_meta.starts[index], req_meta.ends[index],
|
req_meta.starts[index], req_meta.ends[index], req_meta.block_ids, req_meta.layer_id
|
||||||
req_meta.block_ids, req_meta.layer_id)
|
)
|
||||||
key_list.append(key.to_string())
|
key_list.append(key.to_string())
|
||||||
addr_list.append(addr)
|
addr_list.append(addr)
|
||||||
size_list.append(size)
|
size_list.append(size)
|
||||||
key_list_c = key_list[self.tp_rank %
|
key_list_c = key_list[self.tp_rank % len(key_list) :] + key_list[: self.tp_rank % len(key_list)]
|
||||||
len(key_list):] + key_list[:self.tp_rank %
|
addr_list_c = addr_list[self.tp_rank % len(addr_list) :] + addr_list[: self.tp_rank % len(addr_list)]
|
||||||
len(key_list)]
|
size_list_c = size_list[self.tp_rank % len(size_list) :] + size_list[: self.tp_rank % len(size_list)]
|
||||||
addr_list_c = addr_list[self.tp_rank %
|
|
||||||
len(addr_list):] + addr_list[:self.tp_rank %
|
|
||||||
len(addr_list)]
|
|
||||||
size_list_c = size_list[self.tp_rank %
|
|
||||||
len(size_list):] + size_list[:self.tp_rank %
|
|
||||||
len(size_list)]
|
|
||||||
self.m_store.get(key_list_c, addr_list_c, size_list_c)
|
self.m_store.get(key_list_c, addr_list_c, size_list_c)
|
||||||
|
|
||||||
self.request_queue.task_done()
|
self.request_queue.task_done()
|
||||||
|
|||||||
@@ -1,10 +1,9 @@
|
|||||||
from typing import Any, Optional
|
from typing import Any
|
||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
import zmq
|
import zmq
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import \
|
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata
|
||||||
KVConnectorMetadata
|
|
||||||
from vllm.logger import logger
|
from vllm.logger import logger
|
||||||
from vllm.utils.network_utils import make_zmq_socket
|
from vllm.utils.network_utils import make_zmq_socket
|
||||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||||
@@ -14,27 +13,29 @@ from vllm.v1.request import Request
|
|||||||
from vllm.v1.serial_utils import MsgpackEncoder
|
from vllm.v1.serial_utils import MsgpackEncoder
|
||||||
|
|
||||||
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.config_data import (
|
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.config_data import (
|
||||||
AscendConnectorMetadata, LoadSpec, ReqMeta, RequestTracker)
|
AscendConnectorMetadata,
|
||||||
|
LoadSpec,
|
||||||
|
ReqMeta,
|
||||||
|
RequestTracker,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class KVPoolScheduler:
|
class KVPoolScheduler:
|
||||||
|
|
||||||
def __init__(self, vllm_config: "VllmConfig", use_layerwise):
|
def __init__(self, vllm_config: "VllmConfig", use_layerwise):
|
||||||
self.use_layerwise = use_layerwise
|
self.use_layerwise = use_layerwise
|
||||||
self.kv_role = vllm_config.kv_transfer_config.kv_role
|
self.kv_role = vllm_config.kv_transfer_config.kv_role
|
||||||
self.consumer_is_to_load = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
|
self.consumer_is_to_load = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
|
||||||
"consumer_is_to_load", False)
|
"consumer_is_to_load", False
|
||||||
|
)
|
||||||
self.consumer_is_to_put = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
|
self.consumer_is_to_put = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
|
||||||
"consumer_is_to_put", False)
|
"consumer_is_to_put", False
|
||||||
self.load_async = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
|
)
|
||||||
"load_async", False)
|
self.load_async = vllm_config.kv_transfer_config.kv_connector_extra_config.get("load_async", False)
|
||||||
self.client = LookupKeyClient(vllm_config)
|
self.client = LookupKeyClient(vllm_config)
|
||||||
# request_id -> (vllm cached tokes, kvpool cached tokens)
|
# request_id -> (vllm cached tokes, kvpool cached tokens)
|
||||||
self.load_specs: dict[str, LoadSpec] = {}
|
self.load_specs: dict[str, LoadSpec] = {}
|
||||||
self.pcp_size = getattr(vllm_config.parallel_config,
|
self.pcp_size = getattr(vllm_config.parallel_config, "prefill_context_parallel_size", 1)
|
||||||
"prefill_context_parallel_size", 1)
|
self.dcp_size = getattr(vllm_config.parallel_config, "decode_context_parallel_size", 1)
|
||||||
self.dcp_size = getattr(vllm_config.parallel_config,
|
|
||||||
"decode_context_parallel_size", 1)
|
|
||||||
|
|
||||||
self._block_size = vllm_config.cache_config.block_size
|
self._block_size = vllm_config.cache_config.block_size
|
||||||
if self.pcp_size > 1:
|
if self.pcp_size > 1:
|
||||||
@@ -45,9 +46,9 @@ class KVPoolScheduler:
|
|||||||
self._request_trackers: dict[str, RequestTracker] = {}
|
self._request_trackers: dict[str, RequestTracker] = {}
|
||||||
self._preempted_req_ids: set[str] = set()
|
self._preempted_req_ids: set[str] = set()
|
||||||
# Whether to discard partial chunks
|
# Whether to discard partial chunks
|
||||||
self._discard_partial_chunks = (
|
self._discard_partial_chunks = vllm_config.kv_transfer_config.get_from_extra_config(
|
||||||
vllm_config.kv_transfer_config.get_from_extra_config(
|
"discard_partial_chunks", True
|
||||||
"discard_partial_chunks", True))
|
)
|
||||||
self._unfinished_requests: dict[str, tuple[Request, list[int]]] = {}
|
self._unfinished_requests: dict[str, tuple[Request, list[int]]] = {}
|
||||||
self._unfinished_request_ids: set[str] = set()
|
self._unfinished_request_ids: set[str] = set()
|
||||||
|
|
||||||
@@ -72,13 +73,11 @@ class KVPoolScheduler:
|
|||||||
return 0, False
|
return 0, False
|
||||||
|
|
||||||
if self._discard_partial_chunks:
|
if self._discard_partial_chunks:
|
||||||
token_len = len(request.prompt_token_ids
|
token_len = len(request.prompt_token_ids) // self._block_size * self._block_size
|
||||||
) // self._block_size * self._block_size
|
|
||||||
else:
|
else:
|
||||||
token_len = len(request.prompt_token_ids)
|
token_len = len(request.prompt_token_ids)
|
||||||
|
|
||||||
num_external_hit_tokens = self.client.lookup(token_len,
|
num_external_hit_tokens = self.client.lookup(token_len, request.block_hashes)
|
||||||
request.block_hashes)
|
|
||||||
|
|
||||||
if num_external_hit_tokens == request.num_tokens:
|
if num_external_hit_tokens == request.num_tokens:
|
||||||
num_external_hit_tokens -= 1
|
num_external_hit_tokens -= 1
|
||||||
@@ -107,9 +106,7 @@ class KVPoolScheduler:
|
|||||||
|
|
||||||
return need_to_allocate, self.load_async and not self.use_layerwise
|
return need_to_allocate, self.load_async and not self.use_layerwise
|
||||||
|
|
||||||
def update_state_after_alloc(self, request: "Request",
|
def update_state_after_alloc(self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int):
|
||||||
blocks: "KVCacheBlocks",
|
|
||||||
num_external_tokens: int):
|
|
||||||
"""
|
"""
|
||||||
Update KVConnector state after temporary buffer alloc.
|
Update KVConnector state after temporary buffer alloc.
|
||||||
|
|
||||||
@@ -120,8 +117,7 @@ class KVPoolScheduler:
|
|||||||
if num_external_tokens > 0:
|
if num_external_tokens > 0:
|
||||||
local_block_ids = blocks.get_block_ids()[0]
|
local_block_ids = blocks.get_block_ids()[0]
|
||||||
|
|
||||||
self._unfinished_requests[request.request_id] = (request,
|
self._unfinished_requests[request.request_id] = (request, local_block_ids)
|
||||||
local_block_ids)
|
|
||||||
self._unfinished_request_ids.add(request.request_id)
|
self._unfinished_request_ids.add(request.request_id)
|
||||||
if request.request_id not in self.load_specs:
|
if request.request_id not in self.load_specs:
|
||||||
# No KV tokens from external KV cache, return
|
# No KV tokens from external KV cache, return
|
||||||
@@ -133,18 +129,20 @@ class KVPoolScheduler:
|
|||||||
return
|
return
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
num_external_tokens > 0 and num_external_tokens
|
num_external_tokens > 0
|
||||||
== self.load_specs[request.request_id].kvpool_cached_tokens -
|
and num_external_tokens
|
||||||
self.load_specs[request.request_id].vllm_cached_tokens
|
== self.load_specs[request.request_id].kvpool_cached_tokens
|
||||||
), (f"Mismatch in number of tokens: {num_external_tokens} vs "
|
- self.load_specs[request.request_id].vllm_cached_tokens
|
||||||
|
), (
|
||||||
|
f"Mismatch in number of tokens: {num_external_tokens} vs "
|
||||||
f"{self.load_specs[request.request_id].kvpool_cached_tokens} - "
|
f"{self.load_specs[request.request_id].kvpool_cached_tokens} - "
|
||||||
f"{self.load_specs[request.request_id].vllm_cached_tokens}"
|
f"{self.load_specs[request.request_id].vllm_cached_tokens}"
|
||||||
f" for request {request.request_id}")
|
f" for request {request.request_id}"
|
||||||
|
)
|
||||||
|
|
||||||
self.load_specs[request.request_id].can_load = True
|
self.load_specs[request.request_id].can_load = True
|
||||||
|
|
||||||
def build_connector_meta(
|
def build_connector_meta(self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
|
||||||
self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
|
|
||||||
"""Attach the connector metadata to the request object.
|
"""Attach the connector metadata to the request object.
|
||||||
|
|
||||||
This function should NOT modify other fields in the scheduler_output
|
This function should NOT modify other fields in the scheduler_output
|
||||||
@@ -155,8 +153,7 @@ class KVPoolScheduler:
|
|||||||
scheduler_output (SchedulerOutput): the scheduler output object.
|
scheduler_output (SchedulerOutput): the scheduler output object.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
force_skip_save = (self.kv_role == "kv_consumer"
|
force_skip_save = self.kv_role == "kv_consumer" and not self.consumer_is_to_put
|
||||||
and not self.consumer_is_to_put)
|
|
||||||
|
|
||||||
for finished_req_id in scheduler_output.finished_req_ids:
|
for finished_req_id in scheduler_output.finished_req_ids:
|
||||||
self._request_trackers.pop(finished_req_id, None)
|
self._request_trackers.pop(finished_req_id, None)
|
||||||
@@ -173,9 +170,7 @@ class KVPoolScheduler:
|
|||||||
for request in scheduler_output.scheduled_new_reqs:
|
for request in scheduler_output.scheduled_new_reqs:
|
||||||
# Right now, we only load KV for new requests
|
# Right now, we only load KV for new requests
|
||||||
load_spec = self.load_specs.pop(request.req_id, None)
|
load_spec = self.load_specs.pop(request.req_id, None)
|
||||||
num_tokens_to_compute = (
|
num_tokens_to_compute = request.num_computed_tokens + scheduler_output.num_scheduled_tokens[request.req_id]
|
||||||
request.num_computed_tokens +
|
|
||||||
scheduler_output.num_scheduled_tokens[request.req_id])
|
|
||||||
request_tuple = self._unfinished_requests.get(request.req_id)
|
request_tuple = self._unfinished_requests.get(request.req_id)
|
||||||
request_real = request_tuple[0] # type: ignore[index]
|
request_real = request_tuple[0] # type: ignore[index]
|
||||||
if not isinstance(request.block_ids[0], list):
|
if not isinstance(request.block_ids[0], list):
|
||||||
@@ -183,16 +178,17 @@ class KVPoolScheduler:
|
|||||||
else:
|
else:
|
||||||
unfolded_block_ids = request.block_ids[0].copy()
|
unfolded_block_ids = request.block_ids[0].copy()
|
||||||
request_tracker = RequestTracker(
|
request_tracker = RequestTracker(
|
||||||
req_id=request.req_id,
|
req_id=request.req_id,
|
||||||
token_len=num_tokens_to_compute,
|
token_len=num_tokens_to_compute,
|
||||||
allocated_block_ids=unfolded_block_ids,
|
allocated_block_ids=unfolded_block_ids,
|
||||||
num_saved_tokens=0,
|
num_saved_tokens=0,
|
||||||
)
|
)
|
||||||
self._request_trackers[request.req_id] = request_tracker
|
self._request_trackers[request.req_id] = request_tracker
|
||||||
last_chunk_tokens_num = ((len(request.prompt_token_ids) //
|
last_chunk_tokens_num = (
|
||||||
self._block_size * self._block_size)
|
(len(request.prompt_token_ids) // self._block_size * self._block_size)
|
||||||
if self._discard_partial_chunks else len(
|
if self._discard_partial_chunks
|
||||||
request.prompt_token_ids))
|
else len(request.prompt_token_ids)
|
||||||
|
)
|
||||||
|
|
||||||
req_meta = ReqMeta.from_request_tracker(
|
req_meta = ReqMeta.from_request_tracker(
|
||||||
request_tracker,
|
request_tracker,
|
||||||
@@ -200,8 +196,7 @@ class KVPoolScheduler:
|
|||||||
load_spec=load_spec,
|
load_spec=load_spec,
|
||||||
skip_save=force_skip_save,
|
skip_save=force_skip_save,
|
||||||
block_hashes=request_real.block_hashes,
|
block_hashes=request_real.block_hashes,
|
||||||
is_last_chunk=request_tracker.token_len
|
is_last_chunk=request_tracker.token_len >= last_chunk_tokens_num,
|
||||||
>= last_chunk_tokens_num,
|
|
||||||
discard_partial_chunks=self._discard_partial_chunks,
|
discard_partial_chunks=self._discard_partial_chunks,
|
||||||
)
|
)
|
||||||
if req_meta is not None:
|
if req_meta is not None:
|
||||||
@@ -224,8 +219,8 @@ class KVPoolScheduler:
|
|||||||
request_tuple = self._unfinished_requests.get(req_id)
|
request_tuple = self._unfinished_requests.get(req_id)
|
||||||
request_real = request_tuple[0] # type: ignore[index]
|
request_real = request_tuple[0] # type: ignore[index]
|
||||||
num_tokens_to_compute = (
|
num_tokens_to_compute = (
|
||||||
request_real.num_computed_tokens +
|
request_real.num_computed_tokens + scheduler_output.num_scheduled_tokens[req_id]
|
||||||
scheduler_output.num_scheduled_tokens[req_id])
|
)
|
||||||
request_tracker = RequestTracker(
|
request_tracker = RequestTracker(
|
||||||
req_id=req_id,
|
req_id=req_id,
|
||||||
token_len=num_tokens_to_compute,
|
token_len=num_tokens_to_compute,
|
||||||
@@ -233,18 +228,18 @@ class KVPoolScheduler:
|
|||||||
num_saved_tokens=0,
|
num_saved_tokens=0,
|
||||||
)
|
)
|
||||||
self._request_trackers[req_id] = request_tracker
|
self._request_trackers[req_id] = request_tracker
|
||||||
last_chunk_tokens_num = ((len(request_real.prompt_token_ids) //
|
last_chunk_tokens_num = (
|
||||||
self._block_size * self._block_size)
|
(len(request_real.prompt_token_ids) // self._block_size * self._block_size)
|
||||||
if self._discard_partial_chunks else len(
|
if self._discard_partial_chunks
|
||||||
request_real.prompt_token_ids))
|
else len(request_real.prompt_token_ids)
|
||||||
|
)
|
||||||
req_meta = ReqMeta.from_request_tracker(
|
req_meta = ReqMeta.from_request_tracker(
|
||||||
request_tracker,
|
request_tracker,
|
||||||
self._block_size,
|
self._block_size,
|
||||||
load_spec=load_spec,
|
load_spec=load_spec,
|
||||||
skip_save=force_skip_save,
|
skip_save=force_skip_save,
|
||||||
block_hashes=request_real.block_hashes,
|
block_hashes=request_real.block_hashes,
|
||||||
is_last_chunk=request_tracker.token_len
|
is_last_chunk=request_tracker.token_len >= last_chunk_tokens_num,
|
||||||
>= last_chunk_tokens_num,
|
|
||||||
discard_partial_chunks=self._discard_partial_chunks,
|
discard_partial_chunks=self._discard_partial_chunks,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -256,48 +251,44 @@ class KVPoolScheduler:
|
|||||||
if req_tuple:
|
if req_tuple:
|
||||||
request = req_tuple[0]
|
request = req_tuple[0]
|
||||||
num_current_tokens = request_tracker.token_len
|
num_current_tokens = request_tracker.token_len
|
||||||
new_token_ids = request.all_token_ids[
|
new_token_ids = request.all_token_ids[num_current_tokens : num_current_tokens + num_new_tokens]
|
||||||
num_current_tokens:num_current_tokens + num_new_tokens]
|
|
||||||
request_tracker.token_len += len(new_token_ids)
|
request_tracker.token_len += len(new_token_ids)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Request {req_id} is not in _unfinished_requests, "
|
f"Request {req_id} is not in _unfinished_requests, but it is scheduled to be cached"
|
||||||
f"but it is scheduled to be cached")
|
)
|
||||||
num_computed_token = cached_reqs.num_computed_tokens[i]
|
num_computed_token = cached_reqs.num_computed_tokens[i]
|
||||||
if num_computed_token >= len(request.prompt_token_ids):
|
if num_computed_token >= len(request.prompt_token_ids):
|
||||||
continue
|
continue
|
||||||
request_tracker.update(new_block_ids)
|
request_tracker.update(new_block_ids)
|
||||||
|
|
||||||
last_chunk_tokens_num = ((len(request.prompt_token_ids) //
|
last_chunk_tokens_num = (
|
||||||
self._block_size * self._block_size)
|
(len(request.prompt_token_ids) // self._block_size * self._block_size)
|
||||||
if self._discard_partial_chunks else
|
if self._discard_partial_chunks
|
||||||
len(request.prompt_token_ids))
|
else len(request.prompt_token_ids)
|
||||||
|
)
|
||||||
req_meta = ReqMeta.from_request_tracker(
|
req_meta = ReqMeta.from_request_tracker(
|
||||||
request_tracker,
|
request_tracker,
|
||||||
self._block_size,
|
self._block_size,
|
||||||
load_spec=None,
|
load_spec=None,
|
||||||
skip_save=force_skip_save,
|
skip_save=force_skip_save,
|
||||||
block_hashes=request.block_hashes,
|
block_hashes=request.block_hashes,
|
||||||
is_last_chunk=request_tracker.token_len
|
is_last_chunk=request_tracker.token_len >= last_chunk_tokens_num,
|
||||||
>= last_chunk_tokens_num,
|
|
||||||
discard_partial_chunks=self._discard_partial_chunks,
|
discard_partial_chunks=self._discard_partial_chunks,
|
||||||
)
|
)
|
||||||
if req_meta is not None:
|
if req_meta is not None:
|
||||||
meta.add_request(req_meta)
|
meta.add_request(req_meta)
|
||||||
|
|
||||||
request_ids = [
|
request_ids = [req.req_id for req in scheduler_output.scheduled_new_reqs]
|
||||||
req.req_id for req in scheduler_output.scheduled_new_reqs
|
for request_id, (request, block_ids) in self._unfinished_requests.items():
|
||||||
]
|
|
||||||
for request_id, (request,
|
|
||||||
block_ids) in self._unfinished_requests.items():
|
|
||||||
if request_id not in request_ids and request_id not in cached_reqs.req_ids:
|
if request_id not in request_ids and request_id not in cached_reqs.req_ids:
|
||||||
load_spec = self.load_specs.pop(request_id, None)
|
load_spec = self.load_specs.pop(request_id, None)
|
||||||
if not load_spec:
|
if not load_spec:
|
||||||
continue
|
continue
|
||||||
num_tokens_to_compute = load_spec.kvpool_cached_tokens
|
num_tokens_to_compute = load_spec.kvpool_cached_tokens
|
||||||
if (num_tokens_to_compute % self._block_size
|
if (num_tokens_to_compute % self._block_size != 0) and (
|
||||||
!= 0) and (num_tokens_to_compute
|
num_tokens_to_compute == len(request.prompt_token_ids) - 1
|
||||||
== len(request.prompt_token_ids) - 1):
|
):
|
||||||
num_tokens_to_compute = num_tokens_to_compute + 1
|
num_tokens_to_compute = num_tokens_to_compute + 1
|
||||||
request_tracker = RequestTracker(
|
request_tracker = RequestTracker(
|
||||||
req_id=request_id,
|
req_id=request_id,
|
||||||
@@ -324,7 +315,7 @@ class KVPoolScheduler:
|
|||||||
self,
|
self,
|
||||||
request: "Request",
|
request: "Request",
|
||||||
block_ids: list[int],
|
block_ids: list[int],
|
||||||
) -> tuple[bool, Optional[dict[str, Any]]]:
|
) -> tuple[bool, dict[str, Any] | None]:
|
||||||
"""
|
"""
|
||||||
Once a request is finished, determine whether request blocks
|
Once a request is finished, determine whether request blocks
|
||||||
should be freed now or will be sent asynchronously and freed later.
|
should be freed now or will be sent asynchronously and freed later.
|
||||||
@@ -336,13 +327,11 @@ class KVPoolScheduler:
|
|||||||
return False, None
|
return False, None
|
||||||
delay_free_blocks = len(block_ids) > 0
|
delay_free_blocks = len(block_ids) > 0
|
||||||
if delay_free_blocks:
|
if delay_free_blocks:
|
||||||
logger.info("Delaying free of %d blocks for request %s",
|
logger.info("Delaying free of %d blocks for request %s", len(block_ids), request.request_id)
|
||||||
len(block_ids), request.request_id)
|
|
||||||
return delay_free_blocks, None
|
return delay_free_blocks, None
|
||||||
|
|
||||||
|
|
||||||
class LookupKeyClient:
|
class LookupKeyClient:
|
||||||
|
|
||||||
def __init__(self, vllm_config: "VllmConfig"):
|
def __init__(self, vllm_config: "VllmConfig"):
|
||||||
self.encoder = MsgpackEncoder()
|
self.encoder = MsgpackEncoder()
|
||||||
self.ctx = zmq.Context() # type: ignore[attr-defined]
|
self.ctx = zmq.Context() # type: ignore[attr-defined]
|
||||||
|
|||||||
@@ -1,37 +1,45 @@
|
|||||||
import math
|
import math
|
||||||
import threading
|
import threading
|
||||||
from typing import Dict, Generator, Optional, Type
|
from collections.abc import Callable, Generator
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.distributed import (get_decode_context_model_parallel_rank,
|
from vllm.distributed import (
|
||||||
get_decode_context_model_parallel_world_size,
|
get_decode_context_model_parallel_rank,
|
||||||
get_pcp_group, get_tensor_model_parallel_rank,
|
get_decode_context_model_parallel_world_size,
|
||||||
get_tensor_model_parallel_world_size)
|
get_pcp_group,
|
||||||
|
get_tensor_model_parallel_rank,
|
||||||
|
get_tensor_model_parallel_world_size,
|
||||||
|
)
|
||||||
from vllm.logger import logger
|
from vllm.logger import logger
|
||||||
from vllm.v1.core.kv_cache_utils import BlockHash
|
from vllm.v1.core.kv_cache_utils import BlockHash
|
||||||
|
|
||||||
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.backend import \
|
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.backend import Backend
|
||||||
Backend
|
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.memcache_backend import MemcacheBackend
|
||||||
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.memcache_backend import \
|
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.mooncake_backend import MooncakeBackend
|
||||||
MemcacheBackend
|
|
||||||
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.mooncake_backend import \
|
|
||||||
MooncakeBackend
|
|
||||||
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.config_data import (
|
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.config_data import (
|
||||||
AscendConnectorMetadata, ChunkedTokenDatabase, KeyMetadata,
|
AscendConnectorMetadata,
|
||||||
LasyerMultiBlockReqMeta, ReqMeta)
|
ChunkedTokenDatabase,
|
||||||
|
KeyMetadata,
|
||||||
|
LasyerMultiBlockReqMeta,
|
||||||
|
ReqMeta,
|
||||||
|
)
|
||||||
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.kv_transfer import (
|
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.kv_transfer import (
|
||||||
KVCacheStoreLayerRecvingThread, KVCacheStoreLayerSendingThread,
|
KVCacheStoreLayerRecvingThread,
|
||||||
KVCacheStoreRecvingThread, KVCacheStoreSendingThread, KVTransferThread)
|
KVCacheStoreLayerSendingThread,
|
||||||
|
KVCacheStoreRecvingThread,
|
||||||
|
KVCacheStoreSendingThread,
|
||||||
|
KVTransferThread,
|
||||||
|
)
|
||||||
|
|
||||||
backend_map: Dict[str, Type[Backend]] = {
|
backend_map: dict[str, Callable[..., Backend]] = {
|
||||||
"mooncake": MooncakeBackend,
|
"mooncake": MooncakeBackend,
|
||||||
"memcache": MemcacheBackend,
|
"memcache": MemcacheBackend,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class KVPoolWorker:
|
class KVPoolWorker:
|
||||||
#The main class for the cache engine.
|
# The main class for the cache engine.
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -42,9 +50,7 @@ class KVPoolWorker:
|
|||||||
parallel_config = vllm_config.parallel_config
|
parallel_config = vllm_config.parallel_config
|
||||||
self.dp_rank = parallel_config.data_parallel_rank
|
self.dp_rank = parallel_config.data_parallel_rank
|
||||||
self.use_mla = False
|
self.use_mla = False
|
||||||
if (hasattr(model_config, "use_mla")
|
if hasattr(model_config, "use_mla") and isinstance(model_config.use_mla, bool) and model_config.use_mla:
|
||||||
and isinstance(model_config.use_mla, bool)
|
|
||||||
and model_config.use_mla):
|
|
||||||
self.use_mla = True
|
self.use_mla = True
|
||||||
self.use_layerwise = use_layerwize
|
self.use_layerwise = use_layerwize
|
||||||
self.tp_rank = get_tensor_model_parallel_rank()
|
self.tp_rank = get_tensor_model_parallel_rank()
|
||||||
@@ -53,19 +59,16 @@ class KVPoolWorker:
|
|||||||
self.pp_rank = (parallel_config.rank // self.tp_size) % self.pp_size
|
self.pp_rank = (parallel_config.rank // self.tp_size) % self.pp_size
|
||||||
|
|
||||||
self.pcp_size = get_pcp_group().world_size
|
self.pcp_size = get_pcp_group().world_size
|
||||||
self.pcp_rank = get_pcp_group(
|
self.pcp_rank = get_pcp_group().rank_in_group if self.pcp_size > 1 else 0
|
||||||
).rank_in_group if self.pcp_size > 1 else 0
|
|
||||||
self.dcp_size = get_decode_context_model_parallel_world_size()
|
self.dcp_size = get_decode_context_model_parallel_world_size()
|
||||||
self.dcp_rank = get_decode_context_model_parallel_rank(
|
self.dcp_rank = get_decode_context_model_parallel_rank() if self.dcp_size > 1 else 0
|
||||||
) if self.dcp_size > 1 else 0
|
|
||||||
|
|
||||||
self.kv_role = vllm_config.kv_transfer_config.kv_role
|
self.kv_role = vllm_config.kv_transfer_config.kv_role
|
||||||
self.load_async = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
|
self.load_async = vllm_config.kv_transfer_config.kv_connector_extra_config.get("load_async", False)
|
||||||
"load_async", False)
|
|
||||||
self.consumer_is_to_put = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
|
self.consumer_is_to_put = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
|
||||||
"consumer_is_to_put", False)
|
"consumer_is_to_put", False
|
||||||
self.backend = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
|
)
|
||||||
"backend", "mooncake")
|
self.backend = vllm_config.kv_transfer_config.kv_connector_extra_config.get("backend", "mooncake")
|
||||||
self.block_size = vllm_config.cache_config.block_size
|
self.block_size = vllm_config.cache_config.block_size
|
||||||
|
|
||||||
if self.pcp_size > 1:
|
if self.pcp_size > 1:
|
||||||
@@ -88,7 +91,7 @@ class KVPoolWorker:
|
|||||||
self.put_step = 1
|
self.put_step = 1
|
||||||
|
|
||||||
self.metadata = KeyMetadata(
|
self.metadata = KeyMetadata(
|
||||||
model_config.model.rstrip('/').split('/')[-1],
|
model_config.model.rstrip("/").split("/")[-1],
|
||||||
self.head_or_tp_rank,
|
self.head_or_tp_rank,
|
||||||
self.pcp_rank,
|
self.pcp_rank,
|
||||||
self.dcp_rank,
|
self.dcp_rank,
|
||||||
@@ -99,40 +102,28 @@ class KVPoolWorker:
|
|||||||
if self.kv_role == "kv_consumer" and self.consumer_is_to_put:
|
if self.kv_role == "kv_consumer" and self.consumer_is_to_put:
|
||||||
num_hidden_layers = model_config.hf_text_config.num_hidden_layers
|
num_hidden_layers = model_config.hf_text_config.num_hidden_layers
|
||||||
partition_list_str = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
|
partition_list_str = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
|
||||||
"prefill_pp_layer_partition", None)
|
"prefill_pp_layer_partition", None
|
||||||
prefill_pp_size = int(
|
)
|
||||||
vllm_config.kv_transfer_config.kv_connector_extra_config.get(
|
prefill_pp_size = int(vllm_config.kv_transfer_config.kv_connector_extra_config.get("prefill_pp_size", 1))
|
||||||
"prefill_pp_size", 1))
|
|
||||||
|
|
||||||
if partition_list_str is not None:
|
if partition_list_str is not None:
|
||||||
try:
|
try:
|
||||||
partitions = [
|
partitions = [int(layer) for layer in partition_list_str.split(",")]
|
||||||
int(layer) for layer in partition_list_str.split(",")
|
|
||||||
]
|
|
||||||
except ValueError as err:
|
except ValueError as err:
|
||||||
raise ValueError("Invalid partition string: {}".format(
|
raise ValueError("Invalid partition string: {}".format(partition_list_str)) from err
|
||||||
partition_list_str)) from err
|
|
||||||
if len(partitions) != prefill_pp_size:
|
if len(partitions) != prefill_pp_size:
|
||||||
raise ValueError(
|
raise ValueError(f"{len(partitions)=} does not match {prefill_pp_size=}.")
|
||||||
f"{len(partitions)=} does not match {prefill_pp_size=}."
|
|
||||||
)
|
|
||||||
if sum(partitions) != num_hidden_layers:
|
if sum(partitions) != num_hidden_layers:
|
||||||
raise ValueError(
|
raise ValueError(f"{sum(partitions)=} does not match {num_hidden_layers=}.")
|
||||||
f"{sum(partitions)=} does not match {num_hidden_layers=}."
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
layers_per_partition = num_hidden_layers // prefill_pp_size
|
layers_per_partition = num_hidden_layers // prefill_pp_size
|
||||||
partitions = [
|
partitions = [layers_per_partition for _ in range(prefill_pp_size)]
|
||||||
layers_per_partition for _ in range(prefill_pp_size)
|
|
||||||
]
|
|
||||||
|
|
||||||
if remaining_layers := num_hidden_layers % prefill_pp_size:
|
if remaining_layers := num_hidden_layers % prefill_pp_size:
|
||||||
for i in range(2, remaining_layers + 2):
|
for i in range(2, remaining_layers + 2):
|
||||||
partitions[-i] += 1
|
partitions[-i] += 1
|
||||||
|
|
||||||
self.token_database = ChunkedTokenDatabase(self.metadata,
|
self.token_database = ChunkedTokenDatabase(self.metadata, self.block_size, self.use_mla, partitions)
|
||||||
self.block_size,
|
|
||||||
self.use_mla, partitions)
|
|
||||||
|
|
||||||
real_backend = backend_map.get(self.backend.lower())
|
real_backend = backend_map.get(self.backend.lower())
|
||||||
|
|
||||||
@@ -142,10 +133,11 @@ class KVPoolWorker:
|
|||||||
self.put_step = 1
|
self.put_step = 1
|
||||||
|
|
||||||
self.m_store = real_backend( # type: ignore[misc]
|
self.m_store = real_backend( # type: ignore[misc]
|
||||||
parallel_config)
|
parallel_config
|
||||||
|
)
|
||||||
|
|
||||||
self.kv_send_thread: Optional[KVTransferThread] = None
|
self.kv_send_thread: KVTransferThread | None = None
|
||||||
self.kv_recv_thread: Optional[KVTransferThread] = None
|
self.kv_recv_thread: KVTransferThread | None = None
|
||||||
|
|
||||||
self.finished_store_req: set[str] = set()
|
self.finished_store_req: set[str] = set()
|
||||||
|
|
||||||
@@ -162,11 +154,14 @@ class KVPoolWorker:
|
|||||||
block_shape_pe = first_kv_cache_tuple[1].shape[-block_rank:]
|
block_shape_pe = first_kv_cache_tuple[1].shape[-block_rank:]
|
||||||
self.block_len = [
|
self.block_len = [
|
||||||
first_kv_cache[0].element_size() * math.prod(block_shape_norm),
|
first_kv_cache[0].element_size() * math.prod(block_shape_norm),
|
||||||
first_kv_cache[1].element_size() * math.prod(block_shape_pe)
|
first_kv_cache[1].element_size() * math.prod(block_shape_pe),
|
||||||
]
|
]
|
||||||
logger.info(
|
logger.info(
|
||||||
"num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s",
|
"num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s",
|
||||||
self.num_blocks, block_shape_norm, block_shape_pe)
|
self.num_blocks,
|
||||||
|
block_shape_norm,
|
||||||
|
block_shape_pe,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# [num_block, block_size, num_head, hidden_dim]
|
# [num_block, block_size, num_head, hidden_dim]
|
||||||
self.num_blocks = first_kv_cache.shape[0]
|
self.num_blocks = first_kv_cache.shape[0]
|
||||||
@@ -174,11 +169,9 @@ class KVPoolWorker:
|
|||||||
block_rank = 3 # [block_size, kv_heads, head_dim]
|
block_rank = 3 # [block_size, kv_heads, head_dim]
|
||||||
block_shape = first_kv_cache.shape[-block_rank:]
|
block_shape = first_kv_cache.shape[-block_rank:]
|
||||||
self.block_len = [kv_elem_size * math.prod(block_shape)]
|
self.block_len = [kv_elem_size * math.prod(block_shape)]
|
||||||
logger.info("num_blocks: %s, block_shape: %s", self.num_blocks,
|
logger.info("num_blocks: %s, block_shape: %s", self.num_blocks, block_shape)
|
||||||
block_shape)
|
|
||||||
|
|
||||||
logger.info("Registering KV_Caches. use_mla: %s, shape %s",
|
logger.info("Registering KV_Caches. use_mla: %s, shape %s", self.use_mla, first_kv_cache.shape)
|
||||||
self.use_mla, first_kv_cache.shape)
|
|
||||||
|
|
||||||
self.kv_caches = kv_caches
|
self.kv_caches = kv_caches
|
||||||
self.kv_caches_base_addr = []
|
self.kv_caches_base_addr = []
|
||||||
@@ -194,8 +187,7 @@ class KVPoolWorker:
|
|||||||
ptrs.append(base_addr)
|
ptrs.append(base_addr)
|
||||||
lengths.append(region_len)
|
lengths.append(region_len)
|
||||||
else:
|
else:
|
||||||
cache_list = [cache_or_caches
|
cache_list = [cache_or_caches] if self.use_mla else cache_or_caches
|
||||||
] if self.use_mla else cache_or_caches
|
|
||||||
for cache in cache_list:
|
for cache in cache_list:
|
||||||
base_addr = cache.data_ptr()
|
base_addr = cache.data_ptr()
|
||||||
self.kv_caches_base_addr.append(base_addr)
|
self.kv_caches_base_addr.append(base_addr)
|
||||||
@@ -208,33 +200,50 @@ class KVPoolWorker:
|
|||||||
|
|
||||||
if self.use_layerwise:
|
if self.use_layerwise:
|
||||||
self.get_event = threading.Event()
|
self.get_event = threading.Event()
|
||||||
if self.kv_role in ['kv_producer', 'kv_both']:
|
if self.kv_role in ["kv_producer", "kv_both"]:
|
||||||
ready_event_sending = threading.Event()
|
ready_event_sending = threading.Event()
|
||||||
self.kv_send_thread = KVCacheStoreLayerSendingThread(
|
self.kv_send_thread = KVCacheStoreLayerSendingThread(
|
||||||
self.m_store, self.token_database, self.block_size,
|
self.m_store,
|
||||||
self.tp_rank, self.dcp_size, self.put_step,
|
self.token_database,
|
||||||
ready_event_sending, self.num_layers)
|
self.block_size,
|
||||||
|
self.tp_rank,
|
||||||
|
self.dcp_size,
|
||||||
|
self.put_step,
|
||||||
|
ready_event_sending,
|
||||||
|
self.num_layers,
|
||||||
|
)
|
||||||
self.kv_send_thread.start()
|
self.kv_send_thread.start()
|
||||||
ready_event = threading.Event()
|
ready_event = threading.Event()
|
||||||
self.kv_recv_thread = KVCacheStoreLayerRecvingThread(
|
self.kv_recv_thread = KVCacheStoreLayerRecvingThread(
|
||||||
self.m_store, self.token_database, self.block_size,
|
self.m_store,
|
||||||
self.tp_rank, self.dcp_size, ready_event, self.get_event)
|
self.token_database,
|
||||||
|
self.block_size,
|
||||||
|
self.tp_rank,
|
||||||
|
self.dcp_size,
|
||||||
|
ready_event,
|
||||||
|
self.get_event,
|
||||||
|
)
|
||||||
self.kv_recv_thread.start()
|
self.kv_recv_thread.start()
|
||||||
ready_event.wait()
|
ready_event.wait()
|
||||||
else:
|
else:
|
||||||
if self.kv_role in ['kv_producer', 'kv_both'
|
if self.kv_role in ["kv_producer", "kv_both"] or self.consumer_is_to_put:
|
||||||
] or self.consumer_is_to_put:
|
|
||||||
ready_event_sending = threading.Event()
|
ready_event_sending = threading.Event()
|
||||||
self.kv_send_thread = KVCacheStoreSendingThread(
|
self.kv_send_thread = KVCacheStoreSendingThread(
|
||||||
self.m_store, self.token_database, self.block_size,
|
self.m_store,
|
||||||
self.tp_rank, self.dcp_size, self.put_step, self.kv_role,
|
self.token_database,
|
||||||
ready_event_sending)
|
self.block_size,
|
||||||
|
self.tp_rank,
|
||||||
|
self.dcp_size,
|
||||||
|
self.put_step,
|
||||||
|
self.kv_role,
|
||||||
|
ready_event_sending,
|
||||||
|
)
|
||||||
self.kv_send_thread.start()
|
self.kv_send_thread.start()
|
||||||
if self.load_async:
|
if self.load_async:
|
||||||
ready_event = threading.Event()
|
ready_event = threading.Event()
|
||||||
self.kv_recv_thread = KVCacheStoreRecvingThread(
|
self.kv_recv_thread = KVCacheStoreRecvingThread(
|
||||||
self.m_store, self.token_database, self.block_size,
|
self.m_store, self.token_database, self.block_size, self.tp_rank, self.dcp_size, ready_event
|
||||||
self.tp_rank, self.dcp_size, ready_event)
|
)
|
||||||
self.kv_recv_thread.start()
|
self.kv_recv_thread.start()
|
||||||
ready_event.wait()
|
ready_event.wait()
|
||||||
|
|
||||||
@@ -243,12 +252,12 @@ class KVPoolWorker:
|
|||||||
self.layerwise_retrievers = []
|
self.layerwise_retrievers = []
|
||||||
for request in metadata.requests:
|
for request in metadata.requests:
|
||||||
load_spec = request.load_spec
|
load_spec = request.load_spec
|
||||||
if load_spec is None or not load_spec.can_load: #load =0
|
if load_spec is None or not load_spec.can_load: # load =0
|
||||||
continue
|
continue
|
||||||
token_len = request.token_len_chunk
|
token_len = request.token_len_chunk
|
||||||
if (load_spec.kvpool_cached_tokens % self.block_size
|
if (load_spec.kvpool_cached_tokens % self.block_size != 0) and (
|
||||||
!= 0) and (load_spec.kvpool_cached_tokens
|
load_spec.kvpool_cached_tokens == token_len - 1
|
||||||
== token_len - 1):
|
):
|
||||||
token_len = request.load_spec.kvpool_cached_tokens + 1
|
token_len = request.load_spec.kvpool_cached_tokens + 1
|
||||||
else:
|
else:
|
||||||
token_len = request.load_spec.kvpool_cached_tokens
|
token_len = request.load_spec.kvpool_cached_tokens
|
||||||
@@ -260,30 +269,27 @@ class KVPoolWorker:
|
|||||||
else:
|
else:
|
||||||
if self.load_async:
|
if self.load_async:
|
||||||
self.kv_recv_thread.add_request( # type: ignore[union-attr]
|
self.kv_recv_thread.add_request( # type: ignore[union-attr]
|
||||||
request, )
|
request,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
addr_list = []
|
addr_list = []
|
||||||
size_list = []
|
size_list = []
|
||||||
key_list = []
|
key_list = []
|
||||||
mask_num = (request.load_spec.vllm_cached_tokens //
|
mask_num = request.load_spec.vllm_cached_tokens // self.block_size * self.block_size
|
||||||
self.block_size * self.block_size)
|
|
||||||
for start, end, key in self.token_database.process_tokens(
|
for start, end, key in self.token_database.process_tokens(
|
||||||
token_len, request.block_hashes, mask_num):
|
token_len, request.block_hashes, mask_num
|
||||||
addr, size, _ = self.token_database.prepare_value(
|
):
|
||||||
start, end, request.block_ids)
|
addr, size, _ = self.token_database.prepare_value(start, end, request.block_ids)
|
||||||
key_list.append(key.to_string())
|
key_list.append(key.to_string())
|
||||||
addr_list.append(addr)
|
addr_list.append(addr)
|
||||||
size_list.append(size)
|
size_list.append(size)
|
||||||
key_list_c = key_list[self.tp_rank % len(
|
key_list_c = key_list[self.tp_rank % len(key_list) :] + key_list[: self.tp_rank % len(key_list)]
|
||||||
key_list):] + key_list[:self.tp_rank % len(key_list)]
|
addr_list_c = (
|
||||||
addr_list_c = addr_list[self.tp_rank %
|
addr_list[self.tp_rank % len(addr_list) :] + addr_list[: self.tp_rank % len(addr_list)]
|
||||||
len(addr_list
|
)
|
||||||
):] + addr_list[:self.tp_rank %
|
size_list_c = (
|
||||||
len(addr_list)]
|
size_list[self.tp_rank % len(size_list) :] + size_list[: self.tp_rank % len(size_list)]
|
||||||
size_list_c = size_list[self.tp_rank %
|
)
|
||||||
len(size_list
|
|
||||||
):] + size_list[:self.tp_rank %
|
|
||||||
len(size_list)]
|
|
||||||
self.m_store.get(key_list_c, addr_list_c, size_list_c)
|
self.m_store.get(key_list_c, addr_list_c, size_list_c)
|
||||||
|
|
||||||
def wait_for_layer_load(self) -> None:
|
def wait_for_layer_load(self) -> None:
|
||||||
@@ -294,8 +300,7 @@ class KVPoolWorker:
|
|||||||
num_retrieved_tokens = ret_token_mask.sum().item()
|
num_retrieved_tokens = ret_token_mask.sum().item()
|
||||||
logger.debug(f"Retrieved {num_retrieved_tokens} tokens")
|
logger.debug(f"Retrieved {num_retrieved_tokens} tokens")
|
||||||
|
|
||||||
def save_kv_layer(self,
|
def save_kv_layer(self, connector_metadata: AscendConnectorMetadata) -> None:
|
||||||
connector_metadata: AscendConnectorMetadata) -> None:
|
|
||||||
if self.current_layer == 0:
|
if self.current_layer == 0:
|
||||||
self.layerwise_storers = []
|
self.layerwise_storers = []
|
||||||
current_event = None
|
current_event = None
|
||||||
@@ -337,14 +342,16 @@ class KVPoolWorker:
|
|||||||
|
|
||||||
request.current_event = current_event
|
request.current_event = current_event
|
||||||
self.kv_send_thread.add_stored_request( # type: ignore[union-attr]
|
self.kv_send_thread.add_stored_request( # type: ignore[union-attr]
|
||||||
request.req_id)
|
request.req_id
|
||||||
|
)
|
||||||
self.kv_send_thread.add_request( # type: ignore[union-attr]
|
self.kv_send_thread.add_request( # type: ignore[union-attr]
|
||||||
request, )
|
request,
|
||||||
|
)
|
||||||
|
|
||||||
def retrieve_layer(
|
def retrieve_layer(
|
||||||
self,
|
self,
|
||||||
request: ReqMeta,
|
request: ReqMeta,
|
||||||
) -> Generator[Optional[torch.Tensor], None, None]:
|
) -> Generator[torch.Tensor | None, None, None]:
|
||||||
"""
|
"""
|
||||||
Retrieve the KV cache in a layerwise manner.
|
Retrieve the KV cache in a layerwise manner.
|
||||||
|
|
||||||
@@ -364,7 +371,9 @@ class KVPoolWorker:
|
|||||||
token_len = request.token_len_chunk
|
token_len = request.token_len_chunk
|
||||||
mask_num = (
|
mask_num = (
|
||||||
request.load_spec.vllm_cached_tokens # type: ignore[union-attr]
|
request.load_spec.vllm_cached_tokens # type: ignore[union-attr]
|
||||||
// self.block_size * self.block_size)
|
// self.block_size
|
||||||
|
* self.block_size
|
||||||
|
)
|
||||||
num_required_tokens = token_len - mask_num
|
num_required_tokens = token_len - mask_num
|
||||||
|
|
||||||
ret_mask = torch.zeros(token_len, dtype=torch.bool, device="cpu")
|
ret_mask = torch.zeros(token_len, dtype=torch.bool, device="cpu")
|
||||||
@@ -373,8 +382,7 @@ class KVPoolWorker:
|
|||||||
ends = []
|
ends = []
|
||||||
keys = []
|
keys = []
|
||||||
first_flag = True
|
first_flag = True
|
||||||
for start, end, key in self.token_database.process_tokens(
|
for start, end, key in self.token_database.process_tokens(token_len, request.block_hashes, mask_num):
|
||||||
token_len, request.block_hashes, mask_num):
|
|
||||||
keys_multi_layer = key.split_layers(self.num_layers)
|
keys_multi_layer = key.split_layers(self.num_layers)
|
||||||
starts.append(start)
|
starts.append(start)
|
||||||
ends.append(end)
|
ends.append(end)
|
||||||
@@ -386,16 +394,16 @@ class KVPoolWorker:
|
|||||||
keys = [list(row) for row in zip(*keys)] # [num_layer,block_num]
|
keys = [list(row) for row in zip(*keys)] # [num_layer,block_num]
|
||||||
for layer_id, keys_multi_chunk in enumerate(keys):
|
for layer_id, keys_multi_chunk in enumerate(keys):
|
||||||
if not first_flag:
|
if not first_flag:
|
||||||
is_finish = self.get_event.wait(timeout=3) #try---cache
|
is_finish = self.get_event.wait(timeout=3) # try---cache
|
||||||
if not is_finish:
|
if not is_finish:
|
||||||
logger.info("Layerwise get failed")
|
logger.info("Layerwise get failed")
|
||||||
self.get_event.clear()
|
self.get_event.clear()
|
||||||
req_meta = LasyerMultiBlockReqMeta(request.req_id,
|
req_meta = LasyerMultiBlockReqMeta(
|
||||||
keys_multi_chunk, starts,
|
request.req_id, keys_multi_chunk, starts, ends, request.block_ids, layer_id
|
||||||
ends, request.block_ids,
|
)
|
||||||
layer_id)
|
|
||||||
self.kv_recv_thread.add_request( # type: ignore[union-attr, call-arg]
|
self.kv_recv_thread.add_request( # type: ignore[union-attr, call-arg]
|
||||||
req_meta) # type: ignore[union-attr, call-arg, arg-type]
|
req_meta
|
||||||
|
) # type: ignore[union-attr, call-arg, arg-type]
|
||||||
first_flag = False
|
first_flag = False
|
||||||
yield None
|
yield None
|
||||||
else:
|
else:
|
||||||
@@ -405,16 +413,14 @@ class KVPoolWorker:
|
|||||||
yield None
|
yield None
|
||||||
|
|
||||||
retrieved_tokens = torch.sum(ret_mask)
|
retrieved_tokens = torch.sum(ret_mask)
|
||||||
logger.debug(f"Retrieved {retrieved_tokens} "
|
logger.debug(f"Retrieved {retrieved_tokens} out of {num_required_tokens} out of total {token_len} tokens")
|
||||||
f"out of {num_required_tokens} "
|
|
||||||
f"out of total {token_len} tokens")
|
|
||||||
|
|
||||||
yield ret_mask
|
yield ret_mask
|
||||||
|
|
||||||
def store_layer(
|
def store_layer(
|
||||||
self,
|
self,
|
||||||
request: ReqMeta,
|
request: ReqMeta,
|
||||||
current_event: Optional[torch.npu.Event],
|
current_event: torch.npu.Event | None,
|
||||||
) -> Generator[None, None, None]:
|
) -> Generator[None, None, None]:
|
||||||
"""
|
"""
|
||||||
Store the KV cache in a layerwise manner.
|
Store the KV cache in a layerwise manner.
|
||||||
@@ -439,69 +445,88 @@ class KVPoolWorker:
|
|||||||
starts = []
|
starts = []
|
||||||
ends = []
|
ends = []
|
||||||
keys = []
|
keys = []
|
||||||
for start, end, key in self.token_database.process_tokens(
|
for start, end, key in self.token_database.process_tokens(request.token_len_chunk, request.block_hashes):
|
||||||
request.token_len_chunk, request.block_hashes):
|
|
||||||
keys_multi_layer = key.split_layers(self.num_layers)
|
keys_multi_layer = key.split_layers(self.num_layers)
|
||||||
starts.append(start)
|
starts.append(start)
|
||||||
ends.append(end)
|
ends.append(end)
|
||||||
keys.append(keys_multi_layer) #[block_num,layer_num]
|
keys.append(keys_multi_layer) # [block_num,layer_num]
|
||||||
|
|
||||||
if keys:
|
if keys:
|
||||||
keys = [list(row) for row in zip(*keys)] #[layer_num,block_num]
|
keys = [list(row) for row in zip(*keys)] # [layer_num,block_num]
|
||||||
for layer_id, keys_multi_chunk in enumerate(keys):
|
for layer_id, keys_multi_chunk in enumerate(keys):
|
||||||
req_meta = LasyerMultiBlockReqMeta(request.req_id,
|
req_meta = LasyerMultiBlockReqMeta(
|
||||||
keys_multi_chunk, starts,
|
request.req_id,
|
||||||
ends, request.block_ids,
|
keys_multi_chunk,
|
||||||
layer_id,
|
starts,
|
||||||
request.is_last_chunk,
|
ends,
|
||||||
current_event)
|
request.block_ids,
|
||||||
|
layer_id,
|
||||||
|
request.is_last_chunk,
|
||||||
|
current_event,
|
||||||
|
)
|
||||||
self.kv_send_thread.add_request( # type: ignore[union-attr, call-arg]
|
self.kv_send_thread.add_request( # type: ignore[union-attr, call-arg]
|
||||||
req_meta) # type: ignore[union-attr, call-arg, arg-type]
|
req_meta
|
||||||
|
) # type: ignore[union-attr, call-arg, arg-type]
|
||||||
yield
|
yield
|
||||||
else:
|
else:
|
||||||
for layer_id in range(self.num_layers):
|
for layer_id in range(self.num_layers):
|
||||||
yield
|
yield
|
||||||
|
|
||||||
def get_finished(self,
|
def get_finished(self, finished_req_ids: set[str], meta: AscendConnectorMetadata) -> tuple[set[str], set[str]]:
|
||||||
finished_req_ids: set[str], meta:AscendConnectorMetadata) -> tuple[set[str], set[str]]:
|
|
||||||
done_sending = (
|
done_sending = (
|
||||||
self.get_and_clear_finished_requests(
|
self.get_and_clear_finished_requests(
|
||||||
finished_req_ids, meta # type: ignore[union-attr]
|
finished_req_ids,
|
||||||
) if self.kv_role in ['kv_producer', 'kv_both']
|
meta, # type: ignore[union-attr]
|
||||||
or self.consumer_is_to_put else set())
|
)
|
||||||
|
if self.kv_role in ["kv_producer", "kv_both"] or self.consumer_is_to_put
|
||||||
|
else set()
|
||||||
|
)
|
||||||
|
|
||||||
done_recving = (
|
done_recving = (
|
||||||
self.kv_recv_thread.
|
self.kv_recv_thread.get_and_clear_finished_requests( # type: ignore[union-attr]
|
||||||
get_and_clear_finished_requests( # type: ignore[union-attr]
|
)
|
||||||
) if self.load_async else set())
|
if self.load_async
|
||||||
|
else set()
|
||||||
|
)
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Number of completed KV cache send requests: %d, receive "
|
"Number of completed KV cache send requests: %d, receive requests: %d, tp_rank:%d",
|
||||||
"requests: %d, tp_rank:%d", len(done_sending), len(done_recving),
|
len(done_sending),
|
||||||
self.tp_rank)
|
len(done_recving),
|
||||||
|
self.tp_rank,
|
||||||
|
)
|
||||||
return done_sending, done_recving
|
return done_sending, done_recving
|
||||||
|
|
||||||
def get_and_clear_finished_requests(self, finished_req_ids, meta:AscendConnectorMetadata) -> set[str]:
|
def get_and_clear_finished_requests(self, finished_req_ids, meta: AscendConnectorMetadata) -> set[str]:
|
||||||
finished_sending = set()
|
finished_sending = set()
|
||||||
for req_id in meta.preempted_req_ids:
|
for req_id in meta.preempted_req_ids:
|
||||||
self.kv_send_thread.delete_finished_stored_request( # type: ignore[union-attr]
|
self.kv_send_thread.delete_finished_stored_request( # type: ignore[union-attr]
|
||||||
req_id)
|
req_id
|
||||||
|
)
|
||||||
for req_id in self.kv_send_thread.stored_requests.copy( # type: ignore[union-attr]
|
for req_id in self.kv_send_thread.stored_requests.copy( # type: ignore[union-attr]
|
||||||
):
|
):
|
||||||
if self.kv_send_thread.stored_requests[ # type: ignore[union-attr]
|
if (
|
||||||
req_id] == 0 and req_id in self.finished_store_req:
|
self.kv_send_thread.stored_requests[ # type: ignore[union-attr]
|
||||||
|
req_id
|
||||||
|
]
|
||||||
|
== 0
|
||||||
|
and req_id in self.finished_store_req
|
||||||
|
):
|
||||||
self.finished_store_req.remove(req_id)
|
self.finished_store_req.remove(req_id)
|
||||||
finished_sending.add(req_id)
|
finished_sending.add(req_id)
|
||||||
self.kv_send_thread.delete_finished_stored_request( # type: ignore[union-attr]
|
self.kv_send_thread.delete_finished_stored_request( # type: ignore[union-attr]
|
||||||
req_id)
|
req_id
|
||||||
|
)
|
||||||
|
|
||||||
for req_id in finished_req_ids:
|
for req_id in finished_req_ids:
|
||||||
req_remain_jobs = self.kv_send_thread.stored_requests.get( # type: ignore[union-attr]
|
req_remain_jobs = self.kv_send_thread.stored_requests.get( # type: ignore[union-attr]
|
||||||
req_id)
|
req_id
|
||||||
|
)
|
||||||
if req_remain_jobs == 0:
|
if req_remain_jobs == 0:
|
||||||
finished_sending.add(req_id)
|
finished_sending.add(req_id)
|
||||||
self.kv_send_thread.delete_finished_stored_request( # type: ignore[union-attr]
|
self.kv_send_thread.delete_finished_stored_request( # type: ignore[union-attr]
|
||||||
req_id)
|
req_id
|
||||||
|
)
|
||||||
elif req_remain_jobs is not None:
|
elif req_remain_jobs is not None:
|
||||||
self.finished_store_req.add(req_id)
|
self.finished_store_req.add(req_id)
|
||||||
|
|
||||||
@@ -522,8 +547,7 @@ class KVPoolWorker:
|
|||||||
keys = []
|
keys = []
|
||||||
try:
|
try:
|
||||||
starts = []
|
starts = []
|
||||||
for start, end, key in self.token_database.process_tokens(
|
for start, end, key in self.token_database.process_tokens(token_len, block_hashes):
|
||||||
token_len, block_hashes):
|
|
||||||
if use_layerwise:
|
if use_layerwise:
|
||||||
keys_multi_layer = key.split_layers(self.num_layers)
|
keys_multi_layer = key.split_layers(self.num_layers)
|
||||||
for item in keys_multi_layer:
|
for item in keys_multi_layer:
|
||||||
@@ -560,8 +584,7 @@ class KVPoolWorker:
|
|||||||
keys = []
|
keys = []
|
||||||
try:
|
try:
|
||||||
starts = []
|
starts = []
|
||||||
for start, end, key in self.token_database.process_tokens(
|
for start, end, key in self.token_database.process_tokens(token_len, block_hashes):
|
||||||
token_len, block_hashes):
|
|
||||||
if use_layerwise:
|
if use_layerwise:
|
||||||
keys_multi_layer = key.split_layers(self.num_layers)
|
keys_multi_layer = key.split_layers(self.num_layers)
|
||||||
for item in keys_multi_layer:
|
for item in keys_multi_layer:
|
||||||
@@ -574,25 +597,25 @@ class KVPoolWorker:
|
|||||||
for i in range(1, min(self.tp_size, self.num_kv_head)):
|
for i in range(1, min(self.tp_size, self.num_kv_head)):
|
||||||
for item in keys:
|
for item in keys:
|
||||||
new_str = item.replace( # type: ignore[attr-defined]
|
new_str = item.replace( # type: ignore[attr-defined]
|
||||||
"@head_or_tp_rank:0", f"@head_or_tp_rank:{i}", 1)
|
"@head_or_tp_rank:0", f"@head_or_tp_rank:{i}", 1
|
||||||
|
)
|
||||||
multi_tp_keys.append(new_str)
|
multi_tp_keys.append(new_str)
|
||||||
|
|
||||||
for i in range(1, self.pp_size):
|
for i in range(1, self.pp_size):
|
||||||
for item in keys:
|
for item in keys:
|
||||||
new_str = item.replace( # type: ignore[attr-defined]
|
new_str = item.replace( # type: ignore[attr-defined]
|
||||||
"@pp_rank:0", f"@pp_rank:{i}", 1)
|
"@pp_rank:0", f"@pp_rank:{i}", 1
|
||||||
|
)
|
||||||
multi_tp_keys.append(new_str)
|
multi_tp_keys.append(new_str)
|
||||||
|
|
||||||
res = self.m_store.exists(
|
res = self.m_store.exists(multi_tp_keys) # type: ignore[assignment]
|
||||||
multi_tp_keys) # type: ignore[assignment]
|
|
||||||
num_block = len(keys)
|
num_block = len(keys)
|
||||||
if use_layerwise:
|
if use_layerwise:
|
||||||
res = self.check_all_layers_exists(res, self.num_layers)
|
res = self.check_all_layers_exists(res, self.num_layers)
|
||||||
num_block = len(keys) // self.num_layers
|
num_block = len(keys) // self.num_layers
|
||||||
multi_tp_values = [
|
multi_tp_values = [
|
||||||
res[i * num_block:(i + 1) * num_block] # type: ignore[index]
|
res[i * num_block : (i + 1) * num_block] # type: ignore[index]
|
||||||
for i in range(
|
for i in range(min(self.tp_size, self.num_kv_head) * self.pp_size)
|
||||||
min(self.tp_size, self.num_kv_head) * self.pp_size)
|
|
||||||
]
|
]
|
||||||
index = self.find_min_first_non_one_index(multi_tp_values)
|
index = self.find_min_first_non_one_index(multi_tp_values)
|
||||||
if index != -1:
|
if index != -1:
|
||||||
@@ -603,8 +626,7 @@ class KVPoolWorker:
|
|||||||
return start
|
return start
|
||||||
return end
|
return end
|
||||||
|
|
||||||
def check_all_layers_exists(self, res: list[int],
|
def check_all_layers_exists(self, res: list[int], num_layers: int) -> list[int]:
|
||||||
num_layers: int) -> list[int]:
|
|
||||||
total_chunks = len(res) // num_layers
|
total_chunks = len(res) // num_layers
|
||||||
result = []
|
result = []
|
||||||
|
|
||||||
@@ -618,7 +640,6 @@ class KVPoolWorker:
|
|||||||
|
|
||||||
def find_min_first_non_one_index(self, arr):
|
def find_min_first_non_one_index(self, arr):
|
||||||
try:
|
try:
|
||||||
return min(idx for row in arr for idx, val in enumerate(row)
|
return min(idx for row in arr for idx, val in enumerate(row) if val != 1)
|
||||||
if val != 1)
|
|
||||||
except ValueError:
|
except ValueError:
|
||||||
return -1
|
return -1
|
||||||
|
|||||||
@@ -1,20 +1,17 @@
|
|||||||
import time
|
import time
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from vllm.logger import logger
|
from vllm.logger import logger
|
||||||
from vllm.utils.hashing import sha256
|
from vllm.utils.hashing import sha256
|
||||||
from vllm.v1.core.block_pool import BlockPool
|
from vllm.v1.core.block_pool import BlockPool
|
||||||
from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock)
|
from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock
|
||||||
from vllm.v1.core.single_type_kv_cache_manager import \
|
from vllm.v1.core.single_type_kv_cache_manager import get_manager_for_kv_cache_spec
|
||||||
get_manager_for_kv_cache_spec
|
|
||||||
from vllm.v1.kv_cache_interface import KVCacheSpec
|
from vllm.v1.kv_cache_interface import KVCacheSpec
|
||||||
from vllm.v1.metrics.stats import (PrefixCacheStats, CachingMetrics)
|
from vllm.v1.metrics.stats import CachingMetrics, PrefixCacheStats
|
||||||
from vllm.v1.request import Request
|
from vllm.v1.request import Request
|
||||||
|
|
||||||
|
|
||||||
class CPUCacheStats:
|
class CPUCacheStats:
|
||||||
|
|
||||||
def __init__(self, enable_prefix_caching: bool, log_stats: bool = False):
|
def __init__(self, enable_prefix_caching: bool, log_stats: bool = False):
|
||||||
self.enable_prefix_caching = enable_prefix_caching
|
self.enable_prefix_caching = enable_prefix_caching
|
||||||
self.log_stats = log_stats
|
self.log_stats = log_stats
|
||||||
@@ -27,10 +24,9 @@ class CPUCacheStats:
|
|||||||
# Log the prefix cache hit rate every 10 seconds.
|
# Log the prefix cache hit rate every 10 seconds.
|
||||||
if current_time_sec - self.time_sec >= 10:
|
if current_time_sec - self.time_sec >= 10:
|
||||||
self.time_sec = current_time_sec
|
self.time_sec = current_time_sec
|
||||||
logger.info("CPU Prefix cache hit rate: %.1f%%",
|
logger.info("CPU Prefix cache hit rate: %.1f%%", self.cpu_prefix_cache_metrics.hit_rate * 100)
|
||||||
self.cpu_prefix_cache_metrics.hit_rate * 100)
|
|
||||||
|
|
||||||
def make_prefix_cache_stats(self) -> Optional[PrefixCacheStats]:
|
def make_prefix_cache_stats(self) -> PrefixCacheStats | None:
|
||||||
"""Get (and reset) the prefix cache stats.
|
"""Get (and reset) the prefix cache stats.
|
||||||
Returns:
|
Returns:
|
||||||
The current prefix caching stats, or None if logging is disabled.
|
The current prefix caching stats, or None if logging is disabled.
|
||||||
@@ -57,7 +53,6 @@ class CPUCacheStats:
|
|||||||
|
|
||||||
|
|
||||||
class CPUKVCacheManager:
|
class CPUKVCacheManager:
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
kv_cache_spec: KVCacheSpec,
|
kv_cache_spec: KVCacheSpec,
|
||||||
@@ -70,30 +65,26 @@ class CPUKVCacheManager:
|
|||||||
self.num_cpu_blocks = num_cpu_blocks
|
self.num_cpu_blocks = num_cpu_blocks
|
||||||
self.caching_hash_fn = sha256 if caching_hash_algo == "sha256" else hash
|
self.caching_hash_fn = sha256 if caching_hash_algo == "sha256" else hash
|
||||||
self.use_eagle = use_eagle
|
self.use_eagle = use_eagle
|
||||||
self.block_pool = BlockPool(self.num_cpu_blocks, True,
|
self.block_pool = BlockPool(self.num_cpu_blocks, True, enable_kv_cache_events)
|
||||||
enable_kv_cache_events)
|
|
||||||
self.single_type_manager = get_manager_for_kv_cache_spec(
|
self.single_type_manager = get_manager_for_kv_cache_spec(
|
||||||
kv_cache_spec=kv_cache_spec,
|
kv_cache_spec=kv_cache_spec,
|
||||||
block_pool=self.block_pool,
|
block_pool=self.block_pool,
|
||||||
kv_cache_group_id=0,
|
kv_cache_group_id=0,
|
||||||
)
|
)
|
||||||
# Record kv block hashes, avoid redundant computation.
|
# Record kv block hashes, avoid redundant computation.
|
||||||
self.req_to_block_hashes: defaultdict[
|
self.req_to_block_hashes: defaultdict[str, list[BlockHash]] = defaultdict(list)
|
||||||
str, list[BlockHash]] = defaultdict(list)
|
|
||||||
# Record blocks touched in get_matched_num_and_touch().
|
# Record blocks touched in get_matched_num_and_touch().
|
||||||
self.req_to_computed_blocks: defaultdict[
|
self.req_to_computed_blocks: defaultdict[str, list[KVCacheBlock]] = defaultdict(list)
|
||||||
str, list[KVCacheBlock]] = defaultdict(list)
|
|
||||||
# Record the request that failed to allocate.
|
# Record the request that failed to allocate.
|
||||||
self.req_failed_to_allocate: defaultdict[str, bool] = defaultdict(bool)
|
self.req_failed_to_allocate: defaultdict[str, bool] = defaultdict(bool)
|
||||||
self.req_to_num_tokens: defaultdict[str, int] = defaultdict(int)
|
self.req_to_num_tokens: defaultdict[str, int] = defaultdict(int)
|
||||||
self.cpu_cache_stats = CPUCacheStats(enable_prefix_caching=True,
|
self.cpu_cache_stats = CPUCacheStats(enable_prefix_caching=True, log_stats=True)
|
||||||
log_stats=True)
|
|
||||||
# Record request that will be free after finish sending
|
# Record request that will be free after finish sending
|
||||||
self.req_to_free: defaultdict[str, Request] = defaultdict(Request)
|
self.req_to_free: defaultdict[str, Request] = defaultdict(Request)
|
||||||
|
|
||||||
def get_matched_num_and_touch(self, request: Request) -> tuple[int, bool]:
|
def get_matched_num_and_touch(self, request: Request) -> tuple[int, bool]:
|
||||||
# When the request requires prompt logprobs, we skip prefix caching.
|
# When the request requires prompt logprobs, we skip prefix caching.
|
||||||
if (request.sampling_params.prompt_logprobs is not None):
|
if request.sampling_params.prompt_logprobs is not None:
|
||||||
return 0, False
|
return 0, False
|
||||||
request_id = request.request_id
|
request_id = request.request_id
|
||||||
# The block hashes for the request may already be computed
|
# The block hashes for the request may already be computed
|
||||||
@@ -119,10 +110,8 @@ class CPUKVCacheManager:
|
|||||||
|
|
||||||
# cup prefix cache status set and log
|
# cup prefix cache status set and log
|
||||||
assert self.cpu_cache_stats is not None and self.cpu_cache_stats.prefix_cache_stats is not None
|
assert self.cpu_cache_stats is not None and self.cpu_cache_stats.prefix_cache_stats is not None
|
||||||
self.cpu_cache_stats.set_cache_stats(request.num_tokens,
|
self.cpu_cache_stats.set_cache_stats(request.num_tokens, num_computed_tokens)
|
||||||
num_computed_tokens)
|
self.cpu_cache_stats.cpu_prefix_cache_metrics.observe(self.cpu_cache_stats.prefix_cache_stats)
|
||||||
self.cpu_cache_stats.cpu_prefix_cache_metrics.observe(
|
|
||||||
self.cpu_cache_stats.prefix_cache_stats)
|
|
||||||
self.cpu_cache_stats.log()
|
self.cpu_cache_stats.log()
|
||||||
|
|
||||||
return num_computed_tokens, False
|
return num_computed_tokens, False
|
||||||
@@ -130,12 +119,10 @@ class CPUKVCacheManager:
|
|||||||
def _release_ahead_touch(self, request_id: str):
|
def _release_ahead_touch(self, request_id: str):
|
||||||
computed_blocks = self.req_to_computed_blocks[request_id]
|
computed_blocks = self.req_to_computed_blocks[request_id]
|
||||||
if computed_blocks:
|
if computed_blocks:
|
||||||
self.single_type_manager.block_pool.free_blocks(
|
self.single_type_manager.block_pool.free_blocks(reversed(computed_blocks))
|
||||||
reversed(computed_blocks))
|
|
||||||
self.req_to_computed_blocks.pop(request_id, None)
|
self.req_to_computed_blocks.pop(request_id, None)
|
||||||
|
|
||||||
def allocate_slots(self, req_to_num_tokens: dict[str, int],
|
def allocate_slots(self, req_to_num_tokens: dict[str, int], unallocated_req_ids: set[str]) -> dict[str, list[int]]:
|
||||||
unallocated_req_ids: set[str]) -> dict[str, list[int]]:
|
|
||||||
for request_id in unallocated_req_ids:
|
for request_id in unallocated_req_ids:
|
||||||
self._free_slots(request_id)
|
self._free_slots(request_id)
|
||||||
req_to_new_blocks = {}
|
req_to_new_blocks = {}
|
||||||
@@ -143,44 +130,34 @@ class CPUKVCacheManager:
|
|||||||
if self.req_failed_to_allocate[request_id]:
|
if self.req_failed_to_allocate[request_id]:
|
||||||
continue
|
continue
|
||||||
new_computed_blocks = self.req_to_computed_blocks[request_id]
|
new_computed_blocks = self.req_to_computed_blocks[request_id]
|
||||||
num_blocks_to_allocate = (
|
num_blocks_to_allocate = self.single_type_manager.get_num_blocks_to_allocate(
|
||||||
self.single_type_manager.get_num_blocks_to_allocate(
|
request_id=request_id,
|
||||||
request_id=request_id,
|
num_tokens=num_tokens,
|
||||||
num_tokens=num_tokens,
|
new_computed_blocks=new_computed_blocks,
|
||||||
new_computed_blocks=new_computed_blocks,
|
)
|
||||||
))
|
|
||||||
if num_blocks_to_allocate > self.block_pool.get_num_free_blocks():
|
if num_blocks_to_allocate > self.block_pool.get_num_free_blocks():
|
||||||
self._release_ahead_touch(request_id)
|
self._release_ahead_touch(request_id)
|
||||||
self.req_failed_to_allocate[request_id] = True
|
self.req_failed_to_allocate[request_id] = True
|
||||||
continue
|
continue
|
||||||
# Append the new computed blocks to the request blocks until now to
|
# Append the new computed blocks to the request blocks until now to
|
||||||
# avoid the case where the new blocks cannot be allocated.
|
# avoid the case where the new blocks cannot be allocated.
|
||||||
self.single_type_manager.save_new_computed_blocks(
|
self.single_type_manager.save_new_computed_blocks(request_id, new_computed_blocks)
|
||||||
request_id, new_computed_blocks)
|
|
||||||
# Allocate new blocks but do not cache now.
|
# Allocate new blocks but do not cache now.
|
||||||
new_blocks = self.single_type_manager.allocate_new_blocks(
|
new_blocks = self.single_type_manager.allocate_new_blocks(request_id, num_tokens)
|
||||||
request_id, num_tokens)
|
|
||||||
self.req_to_num_tokens[request_id] = num_tokens
|
self.req_to_num_tokens[request_id] = num_tokens
|
||||||
# No need to release ref_cnt because we use officially.
|
# No need to release ref_cnt because we use officially.
|
||||||
self.req_to_computed_blocks.pop(request_id, None)
|
self.req_to_computed_blocks.pop(request_id, None)
|
||||||
req_to_new_blocks[request_id] = [
|
req_to_new_blocks[request_id] = [block.block_id for block in new_computed_blocks + new_blocks]
|
||||||
block.block_id for block in new_computed_blocks + new_blocks
|
|
||||||
]
|
|
||||||
return req_to_new_blocks
|
return req_to_new_blocks
|
||||||
|
|
||||||
def record_request_cache_and_free_slots(self, request: Request):
|
def record_request_cache_and_free_slots(self, request: Request):
|
||||||
logger.debug(
|
logger.debug(f"record_request_cache_and_free_slots for request {request.request_id} in cpu_kv_cache_manager")
|
||||||
f"record_request_cache_and_free_slots for request {request.request_id} in cpu_kv_cache_manager"
|
|
||||||
)
|
|
||||||
self.req_to_free[request.request_id] = request
|
self.req_to_free[request.request_id] = request
|
||||||
|
|
||||||
def cache_and_free_slots(self, request_id: str):
|
def cache_and_free_slots(self, request_id: str):
|
||||||
logger.debug(
|
logger.debug(f"Cache and free slots for request {request_id} in cpu_kv_cache_manager")
|
||||||
f"Cache and free slots for request {request_id} in cpu_kv_cache_manager"
|
|
||||||
)
|
|
||||||
if request_id not in self.req_to_free:
|
if request_id not in self.req_to_free:
|
||||||
logger.Error(
|
logger.Error(f"request {request_id} not in req_to_free, maybe bug!")
|
||||||
f"request {request_id} not in req_to_free, maybe bug!")
|
|
||||||
return
|
return
|
||||||
request = self.req_to_free[request_id]
|
request = self.req_to_free[request_id]
|
||||||
if not self.req_failed_to_allocate[request_id]:
|
if not self.req_failed_to_allocate[request_id]:
|
||||||
@@ -189,8 +166,7 @@ class CPUKVCacheManager:
|
|||||||
self.req_to_num_tokens[request_id],
|
self.req_to_num_tokens[request_id],
|
||||||
)
|
)
|
||||||
self._free_slots(request_id)
|
self._free_slots(request_id)
|
||||||
logger.debug(
|
logger.debug(f"delete request {request_id} in cpu_kv_cache_manager req_to_free")
|
||||||
f"delete request {request_id} in cpu_kv_cache_manager req_to_free")
|
|
||||||
del self.req_to_free[request_id]
|
del self.req_to_free[request_id]
|
||||||
|
|
||||||
def _free_slots(self, request_id: str):
|
def _free_slots(self, request_id: str):
|
||||||
|
|||||||
@@ -5,15 +5,15 @@ import queue
|
|||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
from collections.abc import Sequence
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import TYPE_CHECKING, Any, Optional, Sequence
|
from typing import TYPE_CHECKING, Any, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from vllm.attention.layer import Attention, MLAAttention
|
from vllm.attention.layer import Attention, MLAAttention
|
||||||
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
||||||
from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer
|
from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer
|
||||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole
|
||||||
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
|
|
||||||
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
|
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
|
||||||
from vllm.logger import logger
|
from vllm.logger import logger
|
||||||
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
||||||
@@ -23,12 +23,14 @@ from vllm.v1.core.sched.output import SchedulerOutput
|
|||||||
from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheSpec
|
from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheSpec
|
||||||
|
|
||||||
from vllm_ascend.distributed.kv_transfer.kv_pool.cpu_offload.metadata import (
|
from vllm_ascend.distributed.kv_transfer.kv_pool.cpu_offload.metadata import (
|
||||||
MetadataServer, MetadataServerProc, MLAConfig)
|
MetadataServer,
|
||||||
|
MetadataServerProc,
|
||||||
|
MLAConfig,
|
||||||
|
)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.v1.attention.backend import AttentionMetadata #type: ignore
|
|
||||||
from vllm.forward_context import ForwardContext
|
from vllm.forward_context import ForwardContext
|
||||||
|
from vllm.v1.attention.backend import AttentionMetadata # type: ignore
|
||||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||||
from vllm.v1.request import Request
|
from vllm.v1.request import Request
|
||||||
@@ -59,20 +61,15 @@ class CPUOffloadingConnectorMetadata(KVConnectorMetadata):
|
|||||||
|
|
||||||
|
|
||||||
class CPUOffloadingConnector(KVConnectorBase_V1):
|
class CPUOffloadingConnector(KVConnectorBase_V1):
|
||||||
|
def __init__(
|
||||||
def __init__(self,
|
self, vllm_config: VllmConfig, role: KVConnectorRole, kv_cache_config: Optional["KVCacheConfig"] = None
|
||||||
vllm_config: VllmConfig,
|
):
|
||||||
role: KVConnectorRole,
|
|
||||||
kv_cache_config: Optional["KVCacheConfig"] = None):
|
|
||||||
self._connector_metadata = CPUOffloadingConnectorMetadata(requests={}, finished_req_ids=set())
|
self._connector_metadata = CPUOffloadingConnectorMetadata(requests={}, finished_req_ids=set())
|
||||||
if not vllm_config.cache_config.enable_prefix_caching:
|
if not vllm_config.cache_config.enable_prefix_caching:
|
||||||
self.connector_scheduler: Optional[
|
self.connector_scheduler: CPUOffloadingConnectorScheduler | None = None
|
||||||
CPUOffloadingConnectorScheduler] = None
|
self.connector_worker: CPUOffloadingConnectorWorker | None = None
|
||||||
self.connector_worker: Optional[
|
|
||||||
CPUOffloadingConnectorWorker] = None
|
|
||||||
elif role == KVConnectorRole.SCHEDULER:
|
elif role == KVConnectorRole.SCHEDULER:
|
||||||
self.connector_scheduler = CPUOffloadingConnectorScheduler(
|
self.connector_scheduler = CPUOffloadingConnectorScheduler(vllm_config)
|
||||||
vllm_config)
|
|
||||||
self.connector_worker = None
|
self.connector_worker = None
|
||||||
elif role == KVConnectorRole.WORKER:
|
elif role == KVConnectorRole.WORKER:
|
||||||
self.connector_scheduler = None
|
self.connector_scheduler = None
|
||||||
@@ -82,11 +79,9 @@ class CPUOffloadingConnector(KVConnectorBase_V1):
|
|||||||
# Worker-side methods
|
# Worker-side methods
|
||||||
# ==============================
|
# ==============================
|
||||||
|
|
||||||
def bind_connector_metadata(
|
def bind_connector_metadata(self, connector_metadata: KVConnectorMetadata) -> None:
|
||||||
self, connector_metadata: KVConnectorMetadata) -> None:
|
|
||||||
if self.connector_worker is not None:
|
if self.connector_worker is not None:
|
||||||
assert isinstance(connector_metadata,
|
assert isinstance(connector_metadata, CPUOffloadingConnectorMetadata)
|
||||||
CPUOffloadingConnectorMetadata)
|
|
||||||
self.connector_worker.bind_connector_metadata(connector_metadata)
|
self.connector_worker.bind_connector_metadata(connector_metadata)
|
||||||
|
|
||||||
def clear_connector_metadata(self) -> None:
|
def clear_connector_metadata(self) -> None:
|
||||||
@@ -97,8 +92,7 @@ class CPUOffloadingConnector(KVConnectorBase_V1):
|
|||||||
if self.connector_worker is not None:
|
if self.connector_worker is not None:
|
||||||
self.connector_worker.register_kv_caches(kv_caches)
|
self.connector_worker.register_kv_caches(kv_caches)
|
||||||
|
|
||||||
def start_load_kv(self, forward_context: "ForwardContext",
|
def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
|
||||||
**kwargs) -> None:
|
|
||||||
if self.connector_worker is not None:
|
if self.connector_worker is not None:
|
||||||
self.connector_worker.start_load_kv()
|
self.connector_worker.start_load_kv()
|
||||||
|
|
||||||
@@ -106,53 +100,42 @@ class CPUOffloadingConnector(KVConnectorBase_V1):
|
|||||||
if self.connector_worker is not None:
|
if self.connector_worker is not None:
|
||||||
self.connector_worker.wait_for_layer_load()
|
self.connector_worker.wait_for_layer_load()
|
||||||
|
|
||||||
def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
|
def save_kv_layer(
|
||||||
attn_metadata: "AttentionMetadata", **kwargs) -> None:
|
self, layer_name: str, kv_layer: torch.Tensor, attn_metadata: "AttentionMetadata", **kwargs
|
||||||
|
) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def wait_for_save(self):
|
def wait_for_save(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def get_finished(
|
def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str] | None, set[str] | None]:
|
||||||
self, finished_req_ids: set[str]
|
|
||||||
) -> tuple[Optional[set[str]], Optional[set[str]]]:
|
|
||||||
assert self.connector_worker is not None
|
assert self.connector_worker is not None
|
||||||
return self.connector_worker.get_finished(), None
|
return self.connector_worker.get_finished(), None
|
||||||
|
|
||||||
# Scheduler-side methods
|
# Scheduler-side methods
|
||||||
# ==============================
|
# ==============================
|
||||||
|
|
||||||
def get_num_new_matched_tokens(
|
def get_num_new_matched_tokens(self, request: "Request", num_computed_tokens: int) -> tuple[int, bool]:
|
||||||
self, request: "Request",
|
|
||||||
num_computed_tokens: int) -> tuple[int, bool]:
|
|
||||||
if self.connector_scheduler is not None:
|
if self.connector_scheduler is not None:
|
||||||
return self.connector_scheduler.get_num_new_matched_tokens(
|
return self.connector_scheduler.get_num_new_matched_tokens(request, num_computed_tokens)
|
||||||
request, num_computed_tokens)
|
|
||||||
return 0, False
|
return 0, False
|
||||||
|
|
||||||
def update_state_after_alloc(self, request: "Request",
|
def update_state_after_alloc(self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int):
|
||||||
blocks: "KVCacheBlocks",
|
|
||||||
num_external_tokens: int):
|
|
||||||
if self.connector_scheduler is not None:
|
if self.connector_scheduler is not None:
|
||||||
return self.connector_scheduler.update_state_after_alloc(request)
|
return self.connector_scheduler.update_state_after_alloc(request)
|
||||||
|
|
||||||
def build_connector_meta(
|
def build_connector_meta(self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
|
||||||
self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
|
|
||||||
if self.connector_scheduler is not None:
|
if self.connector_scheduler is not None:
|
||||||
return self.connector_scheduler.build_connector_meta(
|
return self.connector_scheduler.build_connector_meta(scheduler_output)
|
||||||
scheduler_output)
|
|
||||||
return KVConnectorMetadata()
|
return KVConnectorMetadata()
|
||||||
|
|
||||||
def request_finished(
|
def request_finished(self, request: "Request", block_ids: list[int]) -> tuple[bool, dict[str, Any] | None]:
|
||||||
self, request: "Request",
|
|
||||||
block_ids: list[int]) -> tuple[bool, Optional[dict[str, Any]]]:
|
|
||||||
if self.connector_scheduler is not None:
|
if self.connector_scheduler is not None:
|
||||||
self.connector_scheduler.request_finished(request)
|
self.connector_scheduler.request_finished(request)
|
||||||
return True, None
|
return True, None
|
||||||
|
|
||||||
|
|
||||||
class CPUOffloadingConnectorScheduler:
|
class CPUOffloadingConnectorScheduler:
|
||||||
|
|
||||||
def __init__(self, vllm_config: VllmConfig):
|
def __init__(self, vllm_config: VllmConfig):
|
||||||
logger.info("init CPUOffloadingConnectorScheduler")
|
logger.info("init CPUOffloadingConnectorScheduler")
|
||||||
self.vllm_config = vllm_config
|
self.vllm_config = vllm_config
|
||||||
@@ -165,22 +148,17 @@ class CPUOffloadingConnectorScheduler:
|
|||||||
self.zmq_rpc_client = MetadataServer.ZMQRPCClient()
|
self.zmq_rpc_client = MetadataServer.ZMQRPCClient()
|
||||||
self.zmq_rpc_client.call("post_init")
|
self.zmq_rpc_client.call("post_init")
|
||||||
if vllm_config.kv_transfer_config is not None:
|
if vllm_config.kv_transfer_config is not None:
|
||||||
self.swap_in_threshold = vllm_config.kv_transfer_config.get_from_extra_config(
|
self.swap_in_threshold = vllm_config.kv_transfer_config.get_from_extra_config("swap_in_threshold", 0)
|
||||||
"swap_in_threshold", 0)
|
|
||||||
else:
|
else:
|
||||||
self.swap_in_threshold = 0
|
self.swap_in_threshold = 0
|
||||||
logger.info(f"swap_in_threshold: {self.swap_in_threshold}")
|
logger.info(f"swap_in_threshold: {self.swap_in_threshold}")
|
||||||
|
|
||||||
def get_num_new_matched_tokens(
|
def get_num_new_matched_tokens(self, ori_request: "Request", num_computed_tokens: int) -> tuple[int, bool]:
|
||||||
self, ori_request: "Request",
|
|
||||||
num_computed_tokens: int) -> tuple[int, bool]:
|
|
||||||
request = copy.deepcopy(ori_request)
|
request = copy.deepcopy(ori_request)
|
||||||
request.get_hash_new_full_blocks = None
|
request.get_hash_new_full_blocks = None
|
||||||
num_cpu_computed_tokens, load_async = self.zmq_rpc_client.call(
|
num_cpu_computed_tokens, load_async = self.zmq_rpc_client.call("get_matched_num_and_touch", request)
|
||||||
"get_matched_num_and_touch", request)
|
|
||||||
self.num_gpu_computed_tokens[request.request_id] = num_computed_tokens
|
self.num_gpu_computed_tokens[request.request_id] = num_computed_tokens
|
||||||
self.num_cpu_computed_tokens[
|
self.num_cpu_computed_tokens[request.request_id] = num_cpu_computed_tokens
|
||||||
request.request_id] = num_cpu_computed_tokens
|
|
||||||
if num_cpu_computed_tokens - num_computed_tokens >= self.swap_in_threshold:
|
if num_cpu_computed_tokens - num_computed_tokens >= self.swap_in_threshold:
|
||||||
return num_cpu_computed_tokens - num_computed_tokens, load_async
|
return num_cpu_computed_tokens - num_computed_tokens, load_async
|
||||||
else:
|
else:
|
||||||
@@ -189,29 +167,22 @@ class CPUOffloadingConnectorScheduler:
|
|||||||
def update_state_after_alloc(self, request: "Request"):
|
def update_state_after_alloc(self, request: "Request"):
|
||||||
self.allocated_req_ids.add(request.request_id)
|
self.allocated_req_ids.add(request.request_id)
|
||||||
|
|
||||||
def build_connector_meta(
|
def build_connector_meta(self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
|
||||||
self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
|
|
||||||
num_tokens = {}
|
num_tokens = {}
|
||||||
# process scheduled_new_reqs
|
# process scheduled_new_reqs
|
||||||
for req in scheduler_output.scheduled_new_reqs:
|
for req in scheduler_output.scheduled_new_reqs:
|
||||||
req_id = req.req_id
|
req_id = req.req_id
|
||||||
num_tokens[req_id] = (
|
num_tokens[req_id] = req.num_computed_tokens + scheduler_output.num_scheduled_tokens[req_id]
|
||||||
req.num_computed_tokens +
|
|
||||||
scheduler_output.num_scheduled_tokens[req_id])
|
|
||||||
|
|
||||||
# process scheduled_cached_reqs
|
# process scheduled_cached_reqs
|
||||||
cached_reqs = scheduler_output.scheduled_cached_reqs
|
cached_reqs = scheduler_output.scheduled_cached_reqs
|
||||||
for idx, req_id in enumerate(cached_reqs.req_ids):
|
for idx, req_id in enumerate(cached_reqs.req_ids):
|
||||||
num_tokens[req_id] = (
|
num_tokens[req_id] = cached_reqs.num_computed_tokens[idx] + scheduler_output.num_scheduled_tokens[req_id]
|
||||||
cached_reqs.num_computed_tokens[idx] +
|
|
||||||
scheduler_output.num_scheduled_tokens[req_id])
|
|
||||||
|
|
||||||
unallocated_req_ids = set(self.num_gpu_computed_tokens.keys() -
|
unallocated_req_ids = set(
|
||||||
self.allocated_req_ids -
|
self.num_gpu_computed_tokens.keys() - self.allocated_req_ids - scheduler_output.num_scheduled_tokens.keys()
|
||||||
scheduler_output.num_scheduled_tokens.keys())
|
)
|
||||||
new_cpu_block_ids = self.zmq_rpc_client.call("allocate_slots",
|
new_cpu_block_ids = self.zmq_rpc_client.call("allocate_slots", num_tokens, unallocated_req_ids)
|
||||||
num_tokens,
|
|
||||||
unallocated_req_ids)
|
|
||||||
metadata = CPUOffloadingConnectorMetadata(
|
metadata = CPUOffloadingConnectorMetadata(
|
||||||
requests={},
|
requests={},
|
||||||
finished_req_ids=set(self.finished_req_ids),
|
finished_req_ids=set(self.finished_req_ids),
|
||||||
@@ -222,22 +193,22 @@ class CPUOffloadingConnectorScheduler:
|
|||||||
metadata.requests[req_id] = ReqMeta(
|
metadata.requests[req_id] = ReqMeta(
|
||||||
gpu_block_ids=[] if gpu_block_ids is None else gpu_block_ids,
|
gpu_block_ids=[] if gpu_block_ids is None else gpu_block_ids,
|
||||||
cpu_block_ids=new_cpu_block_ids.get(req_id, []),
|
cpu_block_ids=new_cpu_block_ids.get(req_id, []),
|
||||||
num_scheduled_tokens=scheduler_output.
|
num_scheduled_tokens=scheduler_output.num_scheduled_tokens[req_id],
|
||||||
num_scheduled_tokens[req_id],
|
|
||||||
num_computed_tokens=req.num_computed_tokens,
|
num_computed_tokens=req.num_computed_tokens,
|
||||||
num_gpu_computed_tokens=self.num_gpu_computed_tokens[req_id],
|
num_gpu_computed_tokens=self.num_gpu_computed_tokens[req_id],
|
||||||
num_cpu_computed_tokens=self.num_cpu_computed_tokens[req_id])
|
num_cpu_computed_tokens=self.num_cpu_computed_tokens[req_id],
|
||||||
|
)
|
||||||
|
|
||||||
for idx, req_id in enumerate(cached_reqs.req_ids):
|
for idx, req_id in enumerate(cached_reqs.req_ids):
|
||||||
gpu_block_ids = cached_reqs.new_block_ids[idx]
|
gpu_block_ids = cached_reqs.new_block_ids[idx]
|
||||||
metadata.requests[req_id] = ReqMeta(
|
metadata.requests[req_id] = ReqMeta(
|
||||||
gpu_block_ids=[] if gpu_block_ids is None else gpu_block_ids,
|
gpu_block_ids=[] if gpu_block_ids is None else gpu_block_ids,
|
||||||
cpu_block_ids=new_cpu_block_ids.get(req_id, []),
|
cpu_block_ids=new_cpu_block_ids.get(req_id, []),
|
||||||
num_scheduled_tokens=scheduler_output.
|
num_scheduled_tokens=scheduler_output.num_scheduled_tokens[req_id],
|
||||||
num_scheduled_tokens[req_id],
|
|
||||||
num_computed_tokens=cached_reqs.num_computed_tokens[idx],
|
num_computed_tokens=cached_reqs.num_computed_tokens[idx],
|
||||||
num_gpu_computed_tokens=cached_reqs.num_computed_tokens[idx],
|
num_gpu_computed_tokens=cached_reqs.num_computed_tokens[idx],
|
||||||
num_cpu_computed_tokens=cached_reqs.num_computed_tokens[idx])
|
num_cpu_computed_tokens=cached_reqs.num_computed_tokens[idx],
|
||||||
|
)
|
||||||
self.num_gpu_computed_tokens.clear()
|
self.num_gpu_computed_tokens.clear()
|
||||||
self.num_cpu_computed_tokens.clear()
|
self.num_cpu_computed_tokens.clear()
|
||||||
self.allocated_req_ids.clear()
|
self.allocated_req_ids.clear()
|
||||||
@@ -249,12 +220,10 @@ class CPUOffloadingConnectorScheduler:
|
|||||||
request.get_hash_new_full_blocks = None
|
request.get_hash_new_full_blocks = None
|
||||||
self.finished_req_ids.append(request.request_id)
|
self.finished_req_ids.append(request.request_id)
|
||||||
# inform metadata server to record request, and free it after finish sending
|
# inform metadata server to record request, and free it after finish sending
|
||||||
self.zmq_rpc_client.call("record_request_cache_and_free_slots",
|
self.zmq_rpc_client.call("record_request_cache_and_free_slots", request)
|
||||||
request)
|
|
||||||
|
|
||||||
|
|
||||||
class CPUOffloadingConnectorWorker:
|
class CPUOffloadingConnectorWorker:
|
||||||
|
|
||||||
def __init__(self, vllm_config: VllmConfig):
|
def __init__(self, vllm_config: VllmConfig):
|
||||||
logger.info("init CPUOffloadingConnectorWorker")
|
logger.info("init CPUOffloadingConnectorWorker")
|
||||||
self.vllm_config = vllm_config
|
self.vllm_config = vllm_config
|
||||||
@@ -289,7 +258,7 @@ class CPUOffloadingConnectorWorker:
|
|||||||
def init_metadata_server(self, vllm_config: VllmConfig):
|
def init_metadata_server(self, vllm_config: VllmConfig):
|
||||||
self.metadata_thread = threading.Thread(
|
self.metadata_thread = threading.Thread(
|
||||||
target=MetadataServerProc.run_metadata_server,
|
target=MetadataServerProc.run_metadata_server,
|
||||||
args=(vllm_config, ),
|
args=(vllm_config,),
|
||||||
)
|
)
|
||||||
self.metadata_thread.daemon = True
|
self.metadata_thread.daemon = True
|
||||||
self.metadata_thread.start()
|
self.metadata_thread.start()
|
||||||
@@ -304,18 +273,15 @@ class CPUOffloadingConnectorWorker:
|
|||||||
logger.info(f"wait for metadata server to start, error: {e}")
|
logger.info(f"wait for metadata server to start, error: {e}")
|
||||||
time.sleep(1)
|
time.sleep(1)
|
||||||
|
|
||||||
def bind_connector_metadata(
|
def bind_connector_metadata(self, connector_metadata: CPUOffloadingConnectorMetadata) -> None:
|
||||||
self, connector_metadata: CPUOffloadingConnectorMetadata) -> None:
|
|
||||||
for req_id, req in connector_metadata.requests.items():
|
for req_id, req in connector_metadata.requests.items():
|
||||||
if req_id in self.requests:
|
if req_id in self.requests:
|
||||||
self.requests[req_id].update(req)
|
self.requests[req_id].update(req)
|
||||||
req = self.requests[req_id]
|
req = self.requests[req_id]
|
||||||
else:
|
else:
|
||||||
self.requests[req_id] = req
|
self.requests[req_id] = req
|
||||||
for i in range(req.num_gpu_computed_tokens // self.block_size,
|
for i in range(req.num_gpu_computed_tokens // self.block_size, req.num_computed_tokens // self.block_size):
|
||||||
req.num_computed_tokens // self.block_size):
|
self.load_block_mapping.append((req.cpu_block_ids[i], req.gpu_block_ids[i]))
|
||||||
self.load_block_mapping.append(
|
|
||||||
(req.cpu_block_ids[i], req.gpu_block_ids[i]))
|
|
||||||
for req_id in connector_metadata.finished_req_ids:
|
for req_id in connector_metadata.finished_req_ids:
|
||||||
if req_id in self.requests:
|
if req_id in self.requests:
|
||||||
self.save_input_queue.put((req_id, self.requests[req_id]))
|
self.save_input_queue.put((req_id, self.requests[req_id]))
|
||||||
@@ -326,11 +292,11 @@ class CPUOffloadingConnectorWorker:
|
|||||||
def register_kv_caches(self, kv_caches: dict[str, Sequence[torch.Tensor]]):
|
def register_kv_caches(self, kv_caches: dict[str, Sequence[torch.Tensor]]):
|
||||||
self.gpu_kv_caches = kv_caches
|
self.gpu_kv_caches = kv_caches
|
||||||
model_config = self.vllm_config.model_config
|
model_config = self.vllm_config.model_config
|
||||||
mla_config: Optional[MLAConfig] = None
|
mla_config: MLAConfig | None = None
|
||||||
if model_config.use_mla:
|
if model_config.use_mla:
|
||||||
mla_config = MLAConfig(
|
mla_config = MLAConfig(
|
||||||
model_config.hf_text_config.kv_lora_rank,
|
model_config.hf_text_config.kv_lora_rank, model_config.hf_text_config.qk_rope_head_dim
|
||||||
model_config.hf_text_config.qk_rope_head_dim)
|
)
|
||||||
self.cpu_kv_caches = list(
|
self.cpu_kv_caches = list(
|
||||||
self.zmq_rpc_client.call(
|
self.zmq_rpc_client.call(
|
||||||
"init_cpu_kv_caches",
|
"init_cpu_kv_caches",
|
||||||
@@ -338,7 +304,8 @@ class CPUOffloadingConnectorWorker:
|
|||||||
self.tp_rank,
|
self.tp_rank,
|
||||||
get_kv_cache_spec(self.vllm_config),
|
get_kv_cache_spec(self.vllm_config),
|
||||||
mla_config,
|
mla_config,
|
||||||
).values())
|
).values()
|
||||||
|
)
|
||||||
|
|
||||||
def start_load_kv(self) -> None:
|
def start_load_kv(self) -> None:
|
||||||
self.current_layer = 0
|
self.current_layer = 0
|
||||||
@@ -358,10 +325,8 @@ class CPUOffloadingConnectorWorker:
|
|||||||
cpu_kv_caches = self.cpu_kv_caches[layer]
|
cpu_kv_caches = self.cpu_kv_caches[layer]
|
||||||
with torch.npu.stream(self.load_stream):
|
with torch.npu.stream(self.load_stream):
|
||||||
for cpu_block_id, gpu_block_id in self.load_block_mapping:
|
for cpu_block_id, gpu_block_id in self.load_block_mapping:
|
||||||
for gpu_layer_part, cpu_layer_part in zip(
|
for gpu_layer_part, cpu_layer_part in zip(gpu_kv_caches, cpu_kv_caches):
|
||||||
gpu_kv_caches, cpu_kv_caches):
|
gpu_layer_part[gpu_block_id].copy_(cpu_layer_part[cpu_block_id], non_blocking=True)
|
||||||
gpu_layer_part[gpu_block_id].copy_(
|
|
||||||
cpu_layer_part[cpu_block_id], non_blocking=True)
|
|
||||||
|
|
||||||
def get_finished(self) -> set[str]:
|
def get_finished(self) -> set[str]:
|
||||||
done_sending: set[str] = set()
|
done_sending: set[str] = set()
|
||||||
@@ -380,8 +345,7 @@ class CPUOffloadingConnectorWorker:
|
|||||||
self.done_sending_count[req_id] += 1
|
self.done_sending_count[req_id] += 1
|
||||||
other_ranks_finished_ids: list[str] = []
|
other_ranks_finished_ids: list[str] = []
|
||||||
for i in range(1, self.tp_world_size):
|
for i in range(1, self.tp_world_size):
|
||||||
other_ranks_finished_ids.extend(
|
other_ranks_finished_ids.extend(self.tp_group.recv_object(src=i))
|
||||||
self.tp_group.recv_object(src=i))
|
|
||||||
for req_id in other_ranks_finished_ids:
|
for req_id in other_ranks_finished_ids:
|
||||||
self.done_sending_count[req_id] += 1
|
self.done_sending_count[req_id] += 1
|
||||||
all_done_sending: set[str] = set()
|
all_done_sending: set[str] = set()
|
||||||
@@ -391,8 +355,7 @@ class CPUOffloadingConnectorWorker:
|
|||||||
all_done_sending.add(req_id)
|
all_done_sending.add(req_id)
|
||||||
# release cpu_kv_cache after request sending finished
|
# release cpu_kv_cache after request sending finished
|
||||||
# to avoid rpc blocking, use thread to call rpc asynchronously
|
# to avoid rpc blocking, use thread to call rpc asynchronously
|
||||||
sending_finished_thread = threading.Thread(
|
sending_finished_thread = threading.Thread(target=self._sending_finished, args=(all_done_sending,))
|
||||||
target=self._sending_finished, args=(all_done_sending, ))
|
|
||||||
sending_finished_thread.daemon = True
|
sending_finished_thread.daemon = True
|
||||||
sending_finished_thread.start()
|
sending_finished_thread.start()
|
||||||
|
|
||||||
@@ -411,11 +374,10 @@ class CPUOffloadingConnectorWorker:
|
|||||||
while True:
|
while True:
|
||||||
req_id, req = self.save_input_queue.get()
|
req_id, req = self.save_input_queue.get()
|
||||||
for i in range(
|
for i in range(
|
||||||
req.num_cpu_computed_tokens // self.block_size,
|
req.num_cpu_computed_tokens // self.block_size,
|
||||||
min((req.num_computed_tokens + req.num_scheduled_tokens) //
|
min((req.num_computed_tokens + req.num_scheduled_tokens) // self.block_size, len(req.cpu_block_ids)),
|
||||||
self.block_size, len(req.cpu_block_ids))):
|
):
|
||||||
save_block_mapping.append(
|
save_block_mapping.append((req.gpu_block_ids[i], req.cpu_block_ids[i]))
|
||||||
(req.gpu_block_ids[i], req.cpu_block_ids[i]))
|
|
||||||
with torch.npu.stream(self.save_stream):
|
with torch.npu.stream(self.save_stream):
|
||||||
# MLA: kv_layer is tuple[tensor, tensor] means (rope, nope).
|
# MLA: kv_layer is tuple[tensor, tensor] means (rope, nope).
|
||||||
# non-MLA: kv_layer is list[tensor], typically means [k, v].
|
# non-MLA: kv_layer is list[tensor], typically means [k, v].
|
||||||
@@ -425,13 +387,9 @@ class CPUOffloadingConnectorWorker:
|
|||||||
start, step = 0, 1
|
start, step = 0, 1
|
||||||
for i in range(start, len(save_block_mapping), step):
|
for i in range(start, len(save_block_mapping), step):
|
||||||
gpu_block_id, cpu_block_id = save_block_mapping[i]
|
gpu_block_id, cpu_block_id = save_block_mapping[i]
|
||||||
for cpu_kv_caches, gpu_kv_caches in zip(
|
for cpu_kv_caches, gpu_kv_caches in zip(self.cpu_kv_caches, self.gpu_kv_caches.values()):
|
||||||
self.cpu_kv_caches, self.gpu_kv_caches.values()):
|
for cpu_layer_part, gpu_layer_part in zip(cpu_kv_caches, gpu_kv_caches):
|
||||||
for cpu_layer_part, gpu_layer_part in zip(
|
cpu_layer_part[cpu_block_id].copy_(gpu_layer_part[gpu_block_id], non_blocking=True)
|
||||||
cpu_kv_caches, gpu_kv_caches):
|
|
||||||
cpu_layer_part[cpu_block_id].copy_(
|
|
||||||
gpu_layer_part[gpu_block_id],
|
|
||||||
non_blocking=True)
|
|
||||||
self.save_stream.synchronize()
|
self.save_stream.synchronize()
|
||||||
self.save_output_queue.put(req_id)
|
self.save_output_queue.put(req_id)
|
||||||
save_block_mapping.clear()
|
save_block_mapping.clear()
|
||||||
@@ -453,8 +411,7 @@ def get_kv_cache_spec(vllm_config: VllmConfig) -> dict[str, KVCacheSpec]:
|
|||||||
if vllm_config.cache_config.cache_dtype == "auto":
|
if vllm_config.cache_config.cache_dtype == "auto":
|
||||||
kv_cache_dtype = vllm_config.model_config.dtype
|
kv_cache_dtype = vllm_config.model_config.dtype
|
||||||
else:
|
else:
|
||||||
kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
|
kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[vllm_config.cache_config.cache_dtype]
|
||||||
vllm_config.cache_config.cache_dtype]
|
|
||||||
|
|
||||||
kv_cache_spec: dict[str, KVCacheSpec] = {}
|
kv_cache_spec: dict[str, KVCacheSpec] = {}
|
||||||
attn_layers = get_layers_from_vllm_config(vllm_config, AttentionLayerBase)
|
attn_layers = get_layers_from_vllm_config(vllm_config, AttentionLayerBase)
|
||||||
@@ -472,10 +429,8 @@ def get_kv_cache_spec(vllm_config: VllmConfig) -> dict[str, KVCacheSpec]:
|
|||||||
# using DSA. Fix the spec in vLLM is the final way.
|
# using DSA. Fix the spec in vLLM is the final way.
|
||||||
block_size = vllm_config.cache_config.block_size
|
block_size = vllm_config.cache_config.block_size
|
||||||
kv_cache_spec[layer_name] = FullAttentionSpec(
|
kv_cache_spec[layer_name] = FullAttentionSpec(
|
||||||
block_size=block_size,
|
block_size=block_size, num_kv_heads=1, head_size=attn_module.head_size, dtype=kv_cache_dtype
|
||||||
num_kv_heads=1,
|
)
|
||||||
head_size=attn_module.head_size,
|
|
||||||
dtype=kv_cache_dtype)
|
|
||||||
elif spec := attn_module.get_kv_cache_spec(vllm_config):
|
elif spec := attn_module.get_kv_cache_spec(vllm_config):
|
||||||
kv_cache_spec[layer_name] = spec
|
kv_cache_spec[layer_name] = spec
|
||||||
|
|
||||||
@@ -484,8 +439,7 @@ def get_kv_cache_spec(vllm_config: VllmConfig) -> dict[str, KVCacheSpec]:
|
|||||||
|
|
||||||
if len(mamba_layers) > 0:
|
if len(mamba_layers) > 0:
|
||||||
if vllm_config.cache_config.enable_prefix_caching:
|
if vllm_config.cache_config.enable_prefix_caching:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError("Prefix caching is not supported for Mamba yet.")
|
||||||
"Prefix caching is not supported for Mamba yet.")
|
|
||||||
for layer_name, mamba_module in mamba_layers.items():
|
for layer_name, mamba_module in mamba_layers.items():
|
||||||
if spec := mamba_module.get_kv_cache_spec(vllm_config):
|
if spec := mamba_module.get_kv_cache_spec(vllm_config):
|
||||||
kv_cache_spec[layer_name] = spec
|
kv_cache_spec[layer_name] = spec
|
||||||
|
|||||||
@@ -1,9 +1,10 @@
|
|||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
|
from collections.abc import Callable
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from multiprocessing.shared_memory import SharedMemory
|
from multiprocessing.shared_memory import SharedMemory
|
||||||
from typing import Any, Callable, Optional
|
from typing import Any
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
@@ -14,8 +15,7 @@ from vllm.utils.network_utils import make_zmq_socket
|
|||||||
from vllm.utils.torch_utils import get_dtype_size
|
from vllm.utils.torch_utils import get_dtype_size
|
||||||
from vllm.v1.kv_cache_interface import AttentionSpec, MLAAttentionSpec
|
from vllm.v1.kv_cache_interface import AttentionSpec, MLAAttentionSpec
|
||||||
|
|
||||||
from vllm_ascend.distributed.kv_transfer.kv_pool.cpu_offload.cpu_kv_cache_manager import \
|
from vllm_ascend.distributed.kv_transfer.kv_pool.cpu_offload.cpu_kv_cache_manager import CPUKVCacheManager
|
||||||
CPUKVCacheManager
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -30,8 +30,7 @@ def get_cpu_offload_connector(vllm_config: VllmConfig) -> KVTransferConfig:
|
|||||||
if kv_transfer_config.kv_connector == "CPUOffloadingConnector":
|
if kv_transfer_config.kv_connector == "CPUOffloadingConnector":
|
||||||
return kv_transfer_config
|
return kv_transfer_config
|
||||||
elif kv_transfer_config.kv_connector == "MultiConnector":
|
elif kv_transfer_config.kv_connector == "MultiConnector":
|
||||||
ktcs = kv_transfer_config.kv_connector_extra_config.get(
|
ktcs = kv_transfer_config.kv_connector_extra_config.get("connectors")
|
||||||
"connectors")
|
|
||||||
for ktc in ktcs:
|
for ktc in ktcs:
|
||||||
kv_transfer_config = KVTransferConfig(**ktc)
|
kv_transfer_config = KVTransferConfig(**ktc)
|
||||||
if kv_transfer_config.kv_connector == "CPUOffloadingConnector":
|
if kv_transfer_config.kv_connector == "CPUOffloadingConnector":
|
||||||
@@ -44,7 +43,6 @@ class MetadataServer:
|
|||||||
DEFAULT_CPU_SWAP_SPACE_GB = 800
|
DEFAULT_CPU_SWAP_SPACE_GB = 800
|
||||||
|
|
||||||
class ZMQRPCClient:
|
class ZMQRPCClient:
|
||||||
|
|
||||||
def __init__(self, identity=None):
|
def __init__(self, identity=None):
|
||||||
if identity is None:
|
if identity is None:
|
||||||
identity = f"worker-{os.getpid()}-{id(self)}"
|
identity = f"worker-{os.getpid()}-{id(self)}"
|
||||||
@@ -56,7 +54,8 @@ class MetadataServer:
|
|||||||
zmq.DEALER, # type: ignore
|
zmq.DEALER, # type: ignore
|
||||||
bind=False,
|
bind=False,
|
||||||
identity=identity.encode(),
|
identity=identity.encode(),
|
||||||
linger=0)
|
linger=0,
|
||||||
|
)
|
||||||
|
|
||||||
def call(self, func_name: str, *args, **kwargs) -> Any:
|
def call(self, func_name: str, *args, **kwargs) -> Any:
|
||||||
request = (func_name, args, kwargs)
|
request = (func_name, args, kwargs)
|
||||||
@@ -74,11 +73,9 @@ class MetadataServer:
|
|||||||
self.shared_memory_dict = memory_dict
|
self.shared_memory_dict = memory_dict
|
||||||
result = {}
|
result = {}
|
||||||
for key, shm in memory_dict.items():
|
for key, shm in memory_dict.items():
|
||||||
tensor = torch.frombuffer(
|
tensor = torch.frombuffer(shm.buf, dtype=layer_dtype).reshape(layer_size)
|
||||||
shm.buf, dtype=layer_dtype).reshape(layer_size)
|
|
||||||
if mla_config is not None:
|
if mla_config is not None:
|
||||||
tensor = tensor.split(
|
tensor = tensor.split([mla_config.nope_dim, mla_config.rope_dim], dim=-1)
|
||||||
[mla_config.nope_dim, mla_config.rope_dim], dim=-1)
|
|
||||||
result[key] = tensor
|
result[key] = tensor
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@@ -86,7 +83,7 @@ class MetadataServer:
|
|||||||
# will be finalized by outer process
|
# will be finalized by outer process
|
||||||
self.socket.close()
|
self.socket.close()
|
||||||
self.ctx.term()
|
self.ctx.term()
|
||||||
if hasattr(self, 'shared_memory_dict'):
|
if hasattr(self, "shared_memory_dict"):
|
||||||
for shm in self.shared_memory_dict.values():
|
for shm in self.shared_memory_dict.values():
|
||||||
shm.close()
|
shm.close()
|
||||||
|
|
||||||
@@ -96,7 +93,8 @@ class MetadataServer:
|
|||||||
kv_transfer_config = get_cpu_offload_connector(vllm_config)
|
kv_transfer_config = get_cpu_offload_connector(vllm_config)
|
||||||
assert kv_transfer_config is not None
|
assert kv_transfer_config is not None
|
||||||
available_memory_gb = kv_transfer_config.get_from_extra_config(
|
available_memory_gb = kv_transfer_config.get_from_extra_config(
|
||||||
"cpu_swap_space_gb", MetadataServer.DEFAULT_CPU_SWAP_SPACE_GB)
|
"cpu_swap_space_gb", MetadataServer.DEFAULT_CPU_SWAP_SPACE_GB
|
||||||
|
)
|
||||||
self.available_memory = available_memory_gb * 1024 * 1024 * 1024
|
self.available_memory = available_memory_gb * 1024 * 1024 * 1024
|
||||||
logger.info(f"cpu swap space: {self.available_memory} bytes")
|
logger.info(f"cpu swap space: {self.available_memory} bytes")
|
||||||
self.ctx = zmq.Context() # type: ignore
|
self.ctx = zmq.Context() # type: ignore
|
||||||
@@ -105,7 +103,8 @@ class MetadataServer:
|
|||||||
MetadataServer.METADATA_SERVER_ADDRESS,
|
MetadataServer.METADATA_SERVER_ADDRESS,
|
||||||
zmq.ROUTER, # type: ignore
|
zmq.ROUTER, # type: ignore
|
||||||
bind=True,
|
bind=True,
|
||||||
linger=0)
|
linger=0,
|
||||||
|
)
|
||||||
self.functions: dict[str, Callable] = {
|
self.functions: dict[str, Callable] = {
|
||||||
"init_cpu_kv_caches": self.init_cpu_kv_caches,
|
"init_cpu_kv_caches": self.init_cpu_kv_caches,
|
||||||
"post_init": self.post_init,
|
"post_init": self.post_init,
|
||||||
@@ -133,15 +132,11 @@ class MetadataServer:
|
|||||||
tp_rank: int,
|
tp_rank: int,
|
||||||
kv_cache_specs: dict[str, AttentionSpec],
|
kv_cache_specs: dict[str, AttentionSpec],
|
||||||
mla_config: MLAConfig,
|
mla_config: MLAConfig,
|
||||||
) -> tuple[dict[str, SharedMemory], tuple[int, ...], torch.dtype,
|
) -> tuple[dict[str, SharedMemory], tuple[int, ...], torch.dtype, MLAConfig]:
|
||||||
MLAConfig]:
|
|
||||||
logger.info(f"receive pp rank: {pp_rank}, tp rank: {tp_rank}")
|
logger.info(f"receive pp rank: {pp_rank}, tp rank: {tp_rank}")
|
||||||
# follow the assumption that each layer has the same spec
|
# follow the assumption that each layer has the same spec
|
||||||
layer = next(iter(kv_cache_specs.values()))
|
layer = next(iter(kv_cache_specs.values()))
|
||||||
assert all([
|
assert all([layer.page_size_bytes == any.page_size_bytes for any in kv_cache_specs.values()])
|
||||||
layer.page_size_bytes == any.page_size_bytes
|
|
||||||
for any in kv_cache_specs.values()
|
|
||||||
])
|
|
||||||
use_mla = isinstance(layer, MLAAttentionSpec)
|
use_mla = isinstance(layer, MLAAttentionSpec)
|
||||||
# mla shares the same kv cache among different tp
|
# mla shares the same kv cache among different tp
|
||||||
if use_mla:
|
if use_mla:
|
||||||
@@ -154,30 +149,24 @@ class MetadataServer:
|
|||||||
available_memory //= self.pipeline_parallel_size
|
available_memory //= self.pipeline_parallel_size
|
||||||
available_memory //= len(kv_cache_specs)
|
available_memory //= len(kv_cache_specs)
|
||||||
num_blocks = available_memory // layer.page_size_bytes
|
num_blocks = available_memory // layer.page_size_bytes
|
||||||
layer_size = (num_blocks, layer.block_size, layer.num_kv_heads,
|
layer_size = (num_blocks, layer.block_size, layer.num_kv_heads, layer.head_size) # type: ignore
|
||||||
layer.head_size) # type: ignore
|
|
||||||
else:
|
else:
|
||||||
available_memory //= self.world_size
|
available_memory //= self.world_size
|
||||||
available_memory //= len(kv_cache_specs)
|
available_memory //= len(kv_cache_specs)
|
||||||
num_blocks = available_memory // layer.page_size_bytes
|
num_blocks = available_memory // layer.page_size_bytes
|
||||||
layer_size = (2, num_blocks, layer.block_size, layer.num_kv_heads,
|
layer_size = (2, num_blocks, layer.block_size, layer.num_kv_heads, layer.head_size) # type: ignore
|
||||||
layer.head_size) # type: ignore
|
|
||||||
nbytes = math.prod(layer_size) * get_dtype_size(layer.dtype)
|
nbytes = math.prod(layer_size) * get_dtype_size(layer.dtype)
|
||||||
for layer_name in kv_cache_specs.keys():
|
for layer_name in kv_cache_specs:
|
||||||
# only this format can share during ZeroMQ+pickle
|
# only this format can share during ZeroMQ+pickle
|
||||||
shared_memory_dict[
|
shared_memory_dict[layer_name] = MetadataServer._safe_create_shared_memory(
|
||||||
layer_name] = MetadataServer._safe_create_shared_memory(
|
f"cpu_kv_cache_{pp_rank}_{tp_rank}_{layer_name}", nbytes
|
||||||
f"cpu_kv_cache_{pp_rank}_{tp_rank}_{layer_name}", nbytes)
|
)
|
||||||
if use_mla:
|
if use_mla:
|
||||||
assert mla_config is not None
|
assert mla_config is not None
|
||||||
assert layer.head_size == mla_config.rope_dim + mla_config.nope_dim
|
assert layer.head_size == mla_config.rope_dim + mla_config.nope_dim
|
||||||
self.shared_memory[(pp_rank,
|
self.shared_memory[(pp_rank, tp_rank)] = (shared_memory_dict, layer_size, layer.dtype, mla_config)
|
||||||
tp_rank)] = (shared_memory_dict, layer_size,
|
|
||||||
layer.dtype, mla_config)
|
|
||||||
else:
|
else:
|
||||||
self.shared_memory[(pp_rank,
|
self.shared_memory[(pp_rank, tp_rank)] = (shared_memory_dict, layer_size, layer.dtype, None)
|
||||||
tp_rank)] = (shared_memory_dict, layer_size,
|
|
||||||
layer.dtype, None)
|
|
||||||
if self.num_cpu_blocks == -1 or num_blocks < self.num_cpu_blocks:
|
if self.num_cpu_blocks == -1 or num_blocks < self.num_cpu_blocks:
|
||||||
self.num_cpu_blocks = num_blocks
|
self.num_cpu_blocks = num_blocks
|
||||||
self.layer = layer
|
self.layer = layer
|
||||||
@@ -185,23 +174,20 @@ class MetadataServer:
|
|||||||
|
|
||||||
def post_init(self):
    """Create the CPU block manager once and register its operations.

    The first call builds a ``CPUKVCacheManager`` from ``self.layer`` and
    ``self.num_cpu_blocks``, stores it on ``self.cpu_block_manager`` and adds
    four of its bound methods to the ``self.functions`` name-to-callable
    table. Any later call returns immediately without changing anything.

    Raises:
        AssertionError: if ``self.num_cpu_blocks`` is negative when the
            manager is about to be created.
    """
    # different processors in data parallel may call multiple times
    if hasattr(self, "cpu_block_manager"):
        return
    # do shared_memory() at least once
    # NOTE(review): presumably num_cpu_blocks is still the -1 sentinel until a
    # block count has been recorded elsewhere in this class - confirm.
    logger.info(f"assign cpu num blocks: {self.num_cpu_blocks}")
    assert self.num_cpu_blocks >= 0
    self.cpu_block_manager = CPUKVCacheManager(self.layer, self.num_cpu_blocks)
    self.functions.update(
        {
            "get_matched_num_and_touch": self.cpu_block_manager.get_matched_num_and_touch,
            "allocate_slots": self.cpu_block_manager.allocate_slots,
            "record_request_cache_and_free_slots": self.cpu_block_manager.record_request_cache_and_free_slots,
            "cache_and_free_slots": self.cpu_block_manager.cache_and_free_slots,
        }
    )
|
|
||||||
|
|
||||||
def serve_step(self):
|
def serve_step(self):
|
||||||
client_id = self.socket.recv()
|
client_id = self.socket.recv()
|
||||||
@@ -228,8 +214,7 @@ class MetadataServer:
|
|||||||
def shutdown(self):
|
def shutdown(self):
|
||||||
self.socket.close()
|
self.socket.close()
|
||||||
self.ctx.term()
|
self.ctx.term()
|
||||||
socket_path = MetadataServer.METADATA_SERVER_ADDRESS.replace(
|
socket_path = MetadataServer.METADATA_SERVER_ADDRESS.replace("ipc://", "")
|
||||||
"ipc://", "")
|
|
||||||
if os.path.exists(socket_path):
|
if os.path.exists(socket_path):
|
||||||
os.remove(socket_path)
|
os.remove(socket_path)
|
||||||
for cached in self.shared_memory.values():
|
for cached in self.shared_memory.values():
|
||||||
@@ -239,11 +224,9 @@ class MetadataServer:
|
|||||||
|
|
||||||
|
|
||||||
class MetadataServerProc:
|
class MetadataServerProc:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def run_metadata_server(vllm_config: VllmConfig):
|
def run_metadata_server(vllm_config: VllmConfig):
|
||||||
if (not vllm_config.cache_config.enable_prefix_caching
|
if not vllm_config.cache_config.enable_prefix_caching or get_cpu_offload_connector(vllm_config) is None:
|
||||||
or get_cpu_offload_connector(vllm_config) is None):
|
|
||||||
return
|
return
|
||||||
|
|
||||||
shutdown_requested = False
|
shutdown_requested = False
|
||||||
@@ -257,7 +240,7 @@ class MetadataServerProc:
|
|||||||
# Either SIGTERM or SIGINT will terminate the worker
|
# Either SIGTERM or SIGINT will terminate the worker
|
||||||
# signal.signal(signal.SIGTERM, _signal_handler)
|
# signal.signal(signal.SIGTERM, _signal_handler)
|
||||||
# signal.signal(signal.SIGINT, _signal_handler)
|
# signal.signal(signal.SIGINT, _signal_handler)
|
||||||
metadata_server: Optional[MetadataServer] = None
|
metadata_server: MetadataServer | None = None
|
||||||
try:
|
try:
|
||||||
metadata_server = MetadataServer(vllm_config)
|
metadata_server = MetadataServer(vllm_config)
|
||||||
logger.info("Metadata server started.")
|
logger.info("Metadata server started.")
|
||||||
|
|||||||
@@ -4,19 +4,21 @@ from typing import TYPE_CHECKING, Any, Optional
|
|||||||
import torch
|
import torch
|
||||||
from ucm.integration.vllm.ucm_connector import UCMConnector
|
from ucm.integration.vllm.ucm_connector import UCMConnector
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole
|
||||||
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
|
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.v1.core.sched.output import SchedulerOutput
|
from vllm.v1.core.sched.output import SchedulerOutput
|
||||||
|
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
# isort: off
|
# isort: off
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.v1.attention.backend import AttentionMetadata # type: ignore
|
from vllm.v1.attention.backend import AttentionMetadata # type: ignore
|
||||||
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
|
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
|
||||||
KVConnectorPromMetrics, KVConnectorStats, PromMetric, PromMetricT)
|
KVConnectorPromMetrics,
|
||||||
|
KVConnectorStats,
|
||||||
|
PromMetric,
|
||||||
|
PromMetricT,
|
||||||
|
)
|
||||||
from vllm.forward_context import ForwardContext
|
from vllm.forward_context import ForwardContext
|
||||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||||
@@ -25,16 +27,13 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
|
|
||||||
class UCMConnectorV1(KVConnectorBase_V1):
|
class UCMConnectorV1(KVConnectorBase_V1):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
vllm_config: "VllmConfig",
|
vllm_config: "VllmConfig",
|
||||||
role: KVConnectorRole,
|
role: KVConnectorRole,
|
||||||
kv_cache_config: "KVCacheConfig",
|
kv_cache_config: "KVCacheConfig",
|
||||||
):
|
):
|
||||||
super().__init__(vllm_config=vllm_config,
|
super().__init__(vllm_config=vllm_config, role=role, kv_cache_config=kv_cache_config)
|
||||||
role=role,
|
|
||||||
kv_cache_config=kv_cache_config)
|
|
||||||
assert vllm_config.kv_transfer_config is not None
|
assert vllm_config.kv_transfer_config is not None
|
||||||
|
|
||||||
ImplCls = UCMConnector
|
ImplCls = UCMConnector
|
||||||
@@ -60,8 +59,7 @@ class UCMConnectorV1(KVConnectorBase_V1):
|
|||||||
"""
|
"""
|
||||||
self._ucm_engine.register_kv_caches(kv_caches)
|
self._ucm_engine.register_kv_caches(kv_caches)
|
||||||
|
|
||||||
def start_load_kv(self, forward_context: "ForwardContext",
|
def start_load_kv(self, forward_context: "ForwardContext", **kwargs: Any) -> None:
|
||||||
**kwargs: Any) -> None:
|
|
||||||
"""
|
"""
|
||||||
Start loading the KV cache from the connector to vLLM's paged
|
Start loading the KV cache from the connector to vLLM's paged
|
||||||
KV buffer. This is called from the forward context before the
|
KV buffer. This is called from the forward context before the
|
||||||
@@ -110,8 +108,7 @@ class UCMConnectorV1(KVConnectorBase_V1):
|
|||||||
attn_metadata (AttentionMetadata): the attention metadata.
|
attn_metadata (AttentionMetadata): the attention metadata.
|
||||||
**kwargs: additional arguments for the save operation.
|
**kwargs: additional arguments for the save operation.
|
||||||
"""
|
"""
|
||||||
self._ucm_engine.save_kv_layer(layer_name, kv_layer, attn_metadata,
|
self._ucm_engine.save_kv_layer(layer_name, kv_layer, attn_metadata, **kwargs)
|
||||||
**kwargs)
|
|
||||||
|
|
||||||
def wait_for_save(self) -> None:
|
def wait_for_save(self) -> None:
|
||||||
"""
|
"""
|
||||||
@@ -131,8 +128,7 @@ class UCMConnectorV1(KVConnectorBase_V1):
|
|||||||
"""
|
"""
|
||||||
self._ucm_engine.clear_connector_metadata()
|
self._ucm_engine.clear_connector_metadata()
|
||||||
|
|
||||||
def bind_connector_metadata(
|
def bind_connector_metadata(self, connector_metadata: KVConnectorMetadata) -> None:
|
||||||
self, connector_metadata: KVConnectorMetadata) -> None:
|
|
||||||
"""Set the connector metadata from the scheduler.
|
"""Set the connector metadata from the scheduler.
|
||||||
|
|
||||||
This function should be called by the model runner every time
|
This function should be called by the model runner every time
|
||||||
@@ -175,20 +171,15 @@ class UCMConnectorV1(KVConnectorBase_V1):
|
|||||||
the number of tokens that can be loaded from the
|
the number of tokens that can be loaded from the
|
||||||
external KV cache beyond what is already computed.
|
external KV cache beyond what is already computed.
|
||||||
"""
|
"""
|
||||||
return self._ucm_engine.get_num_new_matched_tokens(
|
return self._ucm_engine.get_num_new_matched_tokens(request, num_computed_tokens)
|
||||||
request, num_computed_tokens)
|
|
||||||
|
|
||||||
def update_state_after_alloc(self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int) -> None:
    """
    Update KVConnector state after block allocation.

    This method is a pure delegation: the three arguments are forwarded
    unchanged, positionally and in the same order, to
    ``self._ucm_engine.update_state_after_alloc``. Its return value is
    discarded.

    Args:
        request: passed through to the engine as the first argument.
        blocks: passed through to the engine as the second argument.
        num_external_tokens: passed through to the engine as the third
            argument.
    """
    self._ucm_engine.update_state_after_alloc(request, blocks, num_external_tokens)
|
|
||||||
|
|
||||||
def build_connector_meta(
|
def build_connector_meta(self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
|
||||||
self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
|
|
||||||
"""
|
"""
|
||||||
Build the connector metadata for this step.
|
Build the connector metadata for this step.
|
||||||
|
|
||||||
@@ -222,10 +213,7 @@ class UCMConnectorV1(KVConnectorBase_V1):
|
|||||||
# ==============================
|
# ==============================
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def build_kv_connector_stats(
|
def build_kv_connector_stats(cls, data: dict[str, Any] | None = None) -> Optional["KVConnectorStats"]:
|
||||||
cls,
|
|
||||||
data: dict[str, Any] | None = None
|
|
||||||
) -> Optional["KVConnectorStats"]:
|
|
||||||
"""
|
"""
|
||||||
KVConnectorStats resolution method. This method allows dynamically
|
KVConnectorStats resolution method. This method allows dynamically
|
||||||
registered connectors to return their own KVConnectorStats object,
|
registered connectors to return their own KVConnectorStats object,
|
||||||
|
|||||||
@@ -1,19 +1,16 @@
|
|||||||
import ipaddress
|
|
||||||
import threading
|
import threading
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from mooncake.engine import TransferEngine # type: ignore
|
from mooncake.engine import TransferEngine # type: ignore
|
||||||
|
|
||||||
|
|
||||||
class GlobalTE():
|
class GlobalTE:
|
||||||
|
|
||||||
def __init__(self):
    """Set up an empty holder: no engine yet, buffer not registered, two locks.

    ``transfer_engine`` starts as ``None``; ``get_transfer_engine`` creates
    the engine on first use while holding ``transfer_engine_lock``.
    """
    self.transfer_engine = None
    self.is_register_buffer: bool = False
    self.transfer_engine_lock = threading.Lock()
    # NOTE(review): presumably taken by register_buffer() around buffer
    # registration - its body is not visible here, confirm.
    self.register_buffer_lock = threading.Lock()
|
||||||
|
|
||||||
def get_transfer_engine(self, hostname: str, device_name: Optional[str]):
|
def get_transfer_engine(self, hostname: str, device_name: str | None):
|
||||||
if self.transfer_engine is None:
|
if self.transfer_engine is None:
|
||||||
with self.transfer_engine_lock:
|
with self.transfer_engine_lock:
|
||||||
# Double-Checked Locking
|
# Double-Checked Locking
|
||||||
@@ -22,12 +19,9 @@ class GlobalTE():
|
|||||||
raise RuntimeError("mooncake is not available")
|
raise RuntimeError("mooncake is not available")
|
||||||
self.transfer_engine = TransferEngine()
|
self.transfer_engine = TransferEngine()
|
||||||
device_name = device_name if device_name is not None else ""
|
device_name = device_name if device_name is not None else ""
|
||||||
ret_value = self.transfer_engine.initialize(
|
ret_value = self.transfer_engine.initialize(hostname, "P2PHANDSHAKE", "ascend", device_name)
|
||||||
hostname, "P2PHANDSHAKE", "ascend", device_name)
|
|
||||||
if ret_value != 0:
|
if ret_value != 0:
|
||||||
raise RuntimeError(
|
raise RuntimeError(f"TransferEngine initialization failed with ret_value: {ret_value}")
|
||||||
f"TransferEngine initialization failed with ret_value: {ret_value}"
|
|
||||||
)
|
|
||||||
return self.transfer_engine
|
return self.transfer_engine
|
||||||
|
|
||||||
def register_buffer(self, ptrs: list[int], sizes: list[int]):
|
def register_buffer(self, ptrs: list[int], sizes: list[int]):
|
||||||
|
|||||||
@@ -6,8 +6,7 @@ import torch.distributed as dist
|
|||||||
from vllm_ascend.distributed.parallel_state import get_p_tp_group
|
from vllm_ascend.distributed.parallel_state import get_p_tp_group
|
||||||
|
|
||||||
|
|
||||||
def kv_alltoall_and_rearrange(pd_tp_ratio: int, key: torch.Tensor,
|
def kv_alltoall_and_rearrange(pd_tp_ratio: int, key: torch.Tensor, value: torch.TensorType):
|
||||||
value: torch.TensorType):
|
|
||||||
if pd_tp_ratio <= 1:
|
if pd_tp_ratio <= 1:
|
||||||
return None, None
|
return None, None
|
||||||
elif key is None or value is None:
|
elif key is None or value is None:
|
||||||
@@ -20,22 +19,17 @@ def kv_alltoall_and_rearrange(pd_tp_ratio: int, key: torch.Tensor,
|
|||||||
def alltoall_and_rearrange(tp_ratio: int, input_tensor: torch.Tensor):
    """Run an all-to-all on ``input_tensor`` and rearrange the received data.

    A zero-filled tensor shaped like ``input_tensor`` receives the result of
    ``dist.all_to_all_single`` over the device group returned by
    ``get_p_tp_group()``. That buffer is then handed to ``rearrange_output``
    together with ``tp_ratio`` and the size of dimension 1 of the input.

    Args:
        tp_ratio: forwarded to ``rearrange_output`` as its ``cut_num``.
        input_tensor: tensor to exchange; ``size(1)`` is read as the number
            of KV heads.

    Returns:
        Whatever ``rearrange_output`` returns for the exchanged buffer.
    """
    num_kv_heads = input_tensor.size(1)
    output_tensor = torch.zeros_like(input_tensor)
    dist.all_to_all_single(output_tensor, input_tensor, group=get_p_tp_group().device_group)
    # Drop this frame's reference to the input as soon as the exchange is done.
    # `del` is used rather than rebinding the name to 0, which would give a
    # tensor-typed name an int value. The caller's own reference is unaffected.
    del input_tensor
    result = rearrange_output(output_tensor, tp_ratio, num_kv_heads)
    del output_tensor
    return result
|
||||||
|
|
||||||
|
|
||||||
def rearrange_output(base_output: torch.Tensor, cut_num: int,
|
def rearrange_output(base_output: torch.Tensor, cut_num: int, num_kv_heads: int):
|
||||||
num_kv_heads: int):
|
|
||||||
size_0 = base_output.size(0)
|
size_0 = base_output.size(0)
|
||||||
if size_0 % cut_num != 0:
|
if size_0 % cut_num != 0:
|
||||||
raise ValueError(
|
raise ValueError(f"The size of dim 0 [{size_0}] must be divisible by the cut_num [{cut_num}]")
|
||||||
f"The size of dim 0 [{size_0}] must be divisible by the cut_num [{cut_num}]"
|
|
||||||
)
|
|
||||||
chunk_size = size_0 // cut_num
|
chunk_size = size_0 // cut_num
|
||||||
reshaped = base_output.view(cut_num, chunk_size, -1)
|
reshaped = base_output.view(cut_num, chunk_size, -1)
|
||||||
transposed = reshaped.transpose(0, 1)
|
transposed = reshaped.transpose(0, 1)
|
||||||
@@ -46,16 +40,13 @@ def align_memory(tensor: torch.Tensor, alignment: int) -> torch.Tensor:
|
|||||||
data_ptr = tensor.data_ptr()
|
data_ptr = tensor.data_ptr()
|
||||||
aligned_addr = (data_ptr + alignment - 1) // alignment * alignment
|
aligned_addr = (data_ptr + alignment - 1) // alignment * alignment
|
||||||
offset = (aligned_addr - data_ptr) // tensor.element_size()
|
offset = (aligned_addr - data_ptr) // tensor.element_size()
|
||||||
return tensor[int(offset):]
|
return tensor[int(offset) :]
|
||||||
|
|
||||||
|
|
||||||
def get_transfer_timeout_value():
    """Return the transfer timeout derived from environment variables.

    Resolution order:

    1. ``ASCEND_TRANSFER_TIMEOUT``: when set to a non-blank value it is
       parsed with ``int`` and returned as-is. An empty or whitespace-only
       value is treated as "not set".
    2. Otherwise the value is computed from ``HCCL_RDMA_TIMEOUT``
       (default ``"20"``) and ``HCCL_RDMA_RETRY_CNT`` (default ``"7"``) as
       ``int(4.096 * 2**timeout * retry_cnt // 1000 + 3000)``.

    NOTE(review): the unit is not stated in this function; presumably
    milliseconds (microsecond-based term floor-divided by 1000 plus a fixed
    3000 margin) - confirm against the consumer of this value.

    Returns:
        int: the timeout value.

    Raises:
        ValueError: if a consulted variable holds a non-integer string.
    """
    ascend_transfer_timeout = os.getenv("ASCEND_TRANSFER_TIMEOUT", "").strip()
    if ascend_transfer_timeout:
        return int(ascend_transfer_timeout)
    hccl_rdma_timeout = int(os.getenv("HCCL_RDMA_TIMEOUT", "20"))
    hccl_rdma_retry_cnt = int(os.getenv("HCCL_RDMA_RETRY_CNT", "7"))
    # Same float arithmetic as before so computed values do not shift.
    return int((4.096 * (2**hccl_rdma_timeout)) * hccl_rdma_retry_cnt // 1000 + 3000)
|
|
||||||
|
|||||||
@@ -4,8 +4,7 @@ from vllm.logger import init_logger
|
|||||||
from vllm.utils.platform_utils import is_pin_memory_available
|
from vllm.utils.platform_utils import is_pin_memory_available
|
||||||
from vllm.v1.attention.backend import AttentionBackend # type: ignore
|
from vllm.v1.attention.backend import AttentionBackend # type: ignore
|
||||||
from vllm.v1.kv_offload.mediums import CPULoadStoreSpec, GPULoadStoreSpec
|
from vllm.v1.kv_offload.mediums import CPULoadStoreSpec, GPULoadStoreSpec
|
||||||
from vllm.v1.kv_offload.worker.worker import (OffloadingHandler,
|
from vllm.v1.kv_offload.worker.worker import OffloadingHandler, TransferResult, TransferSpec
|
||||||
TransferResult, TransferSpec)
|
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@@ -44,7 +43,6 @@ def expand_block_ids(
|
|||||||
|
|
||||||
|
|
||||||
class CpuNpuOffloadingHandler(OffloadingHandler):
|
class CpuNpuOffloadingHandler(OffloadingHandler):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
gpu_block_size: int,
|
gpu_block_size: int,
|
||||||
@@ -81,20 +79,22 @@ class CpuNpuOffloadingHandler(OffloadingHandler):
|
|||||||
cpu_shape[num_blocks_idx] = num_cpu_blocks * self.block_size_factor
|
cpu_shape[num_blocks_idx] = num_cpu_blocks * self.block_size_factor
|
||||||
|
|
||||||
logger.debug("Allocating CPU tensor of shape %r", cpu_shape)
|
logger.debug("Allocating CPU tensor of shape %r", cpu_shape)
|
||||||
self.cpu_tensors.append((
|
self.cpu_tensors.append(
|
||||||
torch.zeros(
|
(
|
||||||
cpu_shape,
|
torch.zeros(
|
||||||
dtype=gpu_tensor[0].dtype,
|
cpu_shape,
|
||||||
device="cpu",
|
dtype=gpu_tensor[0].dtype,
|
||||||
pin_memory=pin_memory,
|
device="cpu",
|
||||||
),
|
pin_memory=pin_memory,
|
||||||
torch.zeros(
|
),
|
||||||
cpu_shape,
|
torch.zeros(
|
||||||
dtype=gpu_tensor[0].dtype,
|
cpu_shape,
|
||||||
device="cpu",
|
dtype=gpu_tensor[0].dtype,
|
||||||
pin_memory=pin_memory,
|
device="cpu",
|
||||||
),
|
pin_memory=pin_memory,
|
||||||
))
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
def transfer_async(self, job_id: int, spec: TransferSpec) -> bool:
|
def transfer_async(self, job_id: int, spec: TransferSpec) -> bool:
|
||||||
logger.info("start transfer_async...")
|
logger.info("start transfer_async...")
|
||||||
@@ -123,9 +123,7 @@ class CpuNpuOffloadingHandler(OffloadingHandler):
|
|||||||
dst_sub_blocks_to_skip = -src_blocks.size % dst_block_size_factor
|
dst_sub_blocks_to_skip = -src_blocks.size % dst_block_size_factor
|
||||||
src_sub_block_count = src_blocks.size * src_block_size_factor
|
src_sub_block_count = src_blocks.size * src_block_size_factor
|
||||||
|
|
||||||
assert (
|
assert src_sub_block_count == dst_blocks.size * dst_block_size_factor - dst_sub_blocks_to_skip
|
||||||
src_sub_block_count == dst_blocks.size * dst_block_size_factor -
|
|
||||||
dst_sub_blocks_to_skip)
|
|
||||||
|
|
||||||
src_to_dst = np.empty((src_sub_block_count, 2), dtype=np.int64)
|
src_to_dst = np.empty((src_sub_block_count, 2), dtype=np.int64)
|
||||||
expand_block_ids(src_blocks, src_block_size_factor, src_to_dst[:, 0])
|
expand_block_ids(src_blocks, src_block_size_factor, src_to_dst[:, 0])
|
||||||
@@ -137,18 +135,14 @@ class CpuNpuOffloadingHandler(OffloadingHandler):
|
|||||||
)
|
)
|
||||||
src_to_dst_tensor = torch.from_numpy(src_to_dst)
|
src_to_dst_tensor = torch.from_numpy(src_to_dst)
|
||||||
|
|
||||||
event = self.events_pool.pop(
|
event = self.events_pool.pop() if self.events_pool else torch.npu.Event()
|
||||||
) if self.events_pool else torch.npu.Event()
|
|
||||||
with torch.npu.stream(stream):
|
with torch.npu.stream(stream):
|
||||||
for src_tensor, dst_tensor in zip(src_tensors, dst_tensors):
|
for src_tensor, dst_tensor in zip(src_tensors, dst_tensors):
|
||||||
src_key_cache, src_value_cache = src_tensor[0], src_tensor[1]
|
src_key_cache, src_value_cache = src_tensor[0], src_tensor[1]
|
||||||
dst_key_cache, dst_value_cache = dst_tensor[0], dst_tensor[1]
|
dst_key_cache, dst_value_cache = dst_tensor[0], dst_tensor[1]
|
||||||
|
|
||||||
torch.ops._C_ascend.swap_blocks(src_key_cache, dst_key_cache,
|
torch.ops._C_ascend.swap_blocks(src_key_cache, dst_key_cache, src_to_dst_tensor)
|
||||||
src_to_dst_tensor)
|
torch.ops._C_ascend.swap_blocks(src_value_cache, dst_value_cache, src_to_dst_tensor)
|
||||||
torch.ops._C_ascend.swap_blocks(src_value_cache,
|
|
||||||
dst_value_cache,
|
|
||||||
src_to_dst_tensor)
|
|
||||||
|
|
||||||
event.record(stream)
|
event.record(stream)
|
||||||
|
|
||||||
|
|||||||
@@ -1,48 +1,40 @@
|
|||||||
from collections.abc import Iterator
|
from collections.abc import Iterator
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.v1.attention.backend import AttentionBackend # type: ignore
|
from vllm.v1.attention.backend import AttentionBackend # type: ignore
|
||||||
|
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||||
from vllm.v1.kv_offload.abstract import LoadStoreSpec, OffloadingManager
|
from vllm.v1.kv_offload.abstract import LoadStoreSpec, OffloadingManager
|
||||||
from vllm.v1.kv_offload.backends.cpu import CPUBackend
|
from vllm.v1.kv_offload.backends.cpu import CPUBackend
|
||||||
from vllm.v1.kv_offload.lru_manager import LRUOffloadingManager
|
from vllm.v1.kv_offload.lru_manager import LRUOffloadingManager
|
||||||
from vllm.v1.kv_offload.mediums import CPULoadStoreSpec, GPULoadStoreSpec
|
from vllm.v1.kv_offload.mediums import CPULoadStoreSpec, GPULoadStoreSpec
|
||||||
from vllm.v1.kv_offload.spec import OffloadingSpec
|
from vllm.v1.kv_offload.spec import OffloadingSpec
|
||||||
from vllm.v1.kv_offload.worker.worker import OffloadingHandler
|
from vllm.v1.kv_offload.worker.worker import OffloadingHandler
|
||||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
|
||||||
|
|
||||||
from vllm_ascend.kv_offload.cpu_npu import CpuNpuOffloadingHandler
|
from vllm_ascend.kv_offload.cpu_npu import CpuNpuOffloadingHandler
|
||||||
|
|
||||||
|
|
||||||
class NPUOffloadingSpec(OffloadingSpec):
|
class NPUOffloadingSpec(OffloadingSpec):
|
||||||
|
def __init__(self, vllm_config: VllmConfig, kv_cache_config: KVCacheConfig | None = None):
    """Read the CPU block budget from the extra config and reset cached state.

    Args:
        vllm_config: forwarded to the base-class initialiser.
        kv_cache_config: forwarded to the base-class initialiser; may be
            ``None``.

    Raises:
        ValueError: if ``num_cpu_blocks`` is absent from ``self.extra_config``
            or is falsy (``None``, ``0``, ...).
    """
    super().__init__(vllm_config, kv_cache_config)

    num_cpu_blocks = self.extra_config.get("num_cpu_blocks")
    if not num_cpu_blocks:
        # ValueError is a subclass of Exception, so callers that caught the
        # previous bare Exception keep working.
        raise ValueError("num_cpu_blocks must be specified in kv_connector_extra_config")
    self.num_cpu_blocks: int = num_cpu_blocks

    # scheduler-side
    self._manager: OffloadingManager | None = None

    # worker-side
    self._handler: OffloadingHandler | None = None
|
||||||
|
|
||||||
def get_manager(self) -> OffloadingManager:
    """Return the scheduler-side offloading manager, creating it on first use.

    The manager is an ``LRUOffloadingManager`` wrapping a ``CPUBackend`` sized
    by ``self.offloaded_block_size`` and ``self.num_cpu_blocks``. Events are
    enabled only when ``self.vllm_config.kv_events_config`` is set and its
    ``enable_kv_cache_events`` is truthy. The instance is cached on
    ``self._manager`` and reused by later calls.
    """
    # Compare against None explicitly: a truthiness test would rebuild the
    # manager on every call if the manager object ever evaluated as falsy.
    if self._manager is None:
        kv_events_config = self.vllm_config.kv_events_config
        enable_events = kv_events_config is not None and kv_events_config.enable_kv_cache_events
        self._manager = LRUOffloadingManager(
            CPUBackend(block_size=self.offloaded_block_size, num_blocks=self.num_cpu_blocks),
            enable_events=enable_events,
        )
    return self._manager
|
||||||
@@ -51,8 +43,7 @@ class NPUOffloadingSpec(OffloadingSpec):
|
|||||||
self,
|
self,
|
||||||
kv_caches: dict[str, torch.Tensor],
|
kv_caches: dict[str, torch.Tensor],
|
||||||
attn_backends: dict[str, type[AttentionBackend]],
|
attn_backends: dict[str, type[AttentionBackend]],
|
||||||
) -> Iterator[tuple[type[LoadStoreSpec], type[LoadStoreSpec],
|
) -> Iterator[tuple[type[LoadStoreSpec], type[LoadStoreSpec], OffloadingHandler]]:
|
||||||
OffloadingHandler]]:
|
|
||||||
if not self._handler:
|
if not self._handler:
|
||||||
self._handler = CpuNpuOffloadingHandler(
|
self._handler = CpuNpuOffloadingHandler(
|
||||||
attn_backends=attn_backends,
|
attn_backends=attn_backends,
|
||||||
|
|||||||
@@ -16,11 +16,13 @@
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
def bgmv_shrink(inputs: torch.Tensor,
|
def bgmv_shrink(
|
||||||
lora_a_weights: torch.Tensor,
|
inputs: torch.Tensor,
|
||||||
output_tensor: torch.Tensor,
|
lora_a_weights: torch.Tensor,
|
||||||
lora_indices_tensor: torch.Tensor,
|
output_tensor: torch.Tensor,
|
||||||
scaling: float = 1.0):
|
lora_indices_tensor: torch.Tensor,
|
||||||
|
scaling: float = 1.0,
|
||||||
|
):
|
||||||
return torch.ops._C_ascend.bgmv_shrink(
|
return torch.ops._C_ascend.bgmv_shrink(
|
||||||
inputs,
|
inputs,
|
||||||
lora_a_weights,
|
lora_a_weights,
|
||||||
@@ -30,11 +32,13 @@ def bgmv_shrink(inputs: torch.Tensor,
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def bgmv_expand(inputs: torch.Tensor,
|
def bgmv_expand(
|
||||||
lora_b_weights: torch.Tensor,
|
inputs: torch.Tensor,
|
||||||
output_tensor: torch.Tensor,
|
lora_b_weights: torch.Tensor,
|
||||||
lora_indices_tensor: torch.Tensor,
|
output_tensor: torch.Tensor,
|
||||||
add_inputs: bool = True):
|
lora_indices_tensor: torch.Tensor,
|
||||||
|
add_inputs: bool = True,
|
||||||
|
):
|
||||||
return torch.ops._C_ascend.bgmv_expand(
|
return torch.ops._C_ascend.bgmv_expand(
|
||||||
inputs,
|
inputs,
|
||||||
lora_b_weights,
|
lora_b_weights,
|
||||||
@@ -45,16 +49,18 @@ def bgmv_expand(inputs: torch.Tensor,
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def bgmv_expand_slice(
    inputs: torch.Tensor,
    lora_b_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    slice_offset: int,
    slice_size: int,
    add_inputs: bool = True,
):
    """Call the ``bgmv_expand`` Ascend op restricted to an output slice.

    This dispatches to ``torch.ops._C_ascend.bgmv_expand`` (the same op name
    the non-slice wrapper uses) with ``slice_offset`` and ``slice_size``
    appended.

    Args:
        inputs: forwarded as the first op argument.
        lora_b_weights: forwarded as the second op argument.
        output_tensor: forwarded as the fourth op argument.
        lora_indices_tensor: forwarded as the third op argument.
        slice_offset: forwarded as the fifth op argument.
        slice_size: forwarded as the sixth op argument.
        add_inputs: accepted for signature compatibility but not forwarded
            to the op by this wrapper.

    Returns:
        The value returned by the underlying op.
    """
    # The op takes the indices before the output tensor, which differs from
    # this wrapper's own parameter order.
    return torch.ops._C_ascend.bgmv_expand(
        inputs, lora_b_weights, lora_indices_tensor, output_tensor, slice_offset, slice_size
    )
|
||||||
|
|
||||||
|
|
||||||
def sgmv_shrink(
|
def sgmv_shrink(
|
||||||
@@ -69,21 +75,23 @@ def sgmv_shrink(
|
|||||||
token_nums: int,
|
token_nums: int,
|
||||||
scaling: float,
|
scaling: float,
|
||||||
):
|
):
|
||||||
return torch.ops._C_ascend.sgmv_shrink(inputs, lora_a_weights,
|
return torch.ops._C_ascend.sgmv_shrink(
|
||||||
lora_indices_tensor, seq_len_tensor,
|
inputs, lora_a_weights, lora_indices_tensor, seq_len_tensor, output_tensor, scaling
|
||||||
output_tensor, scaling)
|
)
|
||||||
|
|
||||||
|
|
||||||
def sgmv_expand(inputs: torch.Tensor,
|
def sgmv_expand(
|
||||||
lora_b_weights: torch.Tensor,
|
inputs: torch.Tensor,
|
||||||
output_tensor: torch.Tensor,
|
lora_b_weights: torch.Tensor,
|
||||||
b_seq_start_loc: torch.Tensor,
|
output_tensor: torch.Tensor,
|
||||||
seq_len_tensor: torch.Tensor,
|
b_seq_start_loc: torch.Tensor,
|
||||||
lora_indices_tensor: torch.Tensor,
|
seq_len_tensor: torch.Tensor,
|
||||||
batches: int,
|
lora_indices_tensor: torch.Tensor,
|
||||||
max_seq_length: int,
|
batches: int,
|
||||||
token_nums: int,
|
max_seq_length: int,
|
||||||
add_inputs: bool = False):
|
token_nums: int,
|
||||||
|
add_inputs: bool = False,
|
||||||
|
):
|
||||||
return torch.ops._C_ascend.sgmv_expand(
|
return torch.ops._C_ascend.sgmv_expand(
|
||||||
inputs,
|
inputs,
|
||||||
lora_b_weights,
|
lora_b_weights,
|
||||||
@@ -95,19 +103,20 @@ def sgmv_expand(inputs: torch.Tensor,
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def sgmv_expand_slice(inputs: torch.Tensor,
|
def sgmv_expand_slice(
|
||||||
lora_b_weights: torch.Tensor,
|
inputs: torch.Tensor,
|
||||||
output_tensor: torch.Tensor,
|
lora_b_weights: torch.Tensor,
|
||||||
b_seq_start_loc: torch.Tensor,
|
output_tensor: torch.Tensor,
|
||||||
seq_len_tensor: torch.Tensor,
|
b_seq_start_loc: torch.Tensor,
|
||||||
lora_indices_tensor: torch.Tensor,
|
seq_len_tensor: torch.Tensor,
|
||||||
batches: int,
|
lora_indices_tensor: torch.Tensor,
|
||||||
max_seq_length: int,
|
batches: int,
|
||||||
token_nums: int,
|
max_seq_length: int,
|
||||||
slice_offset: int,
|
token_nums: int,
|
||||||
slice_size: int,
|
slice_offset: int,
|
||||||
add_inputs: bool = False):
|
slice_size: int,
|
||||||
return torch.ops._C_ascend.sgmv_expand(inputs, lora_b_weights,
|
add_inputs: bool = False,
|
||||||
lora_indices_tensor, seq_len_tensor,
|
):
|
||||||
output_tensor, slice_offset,
|
return torch.ops._C_ascend.sgmv_expand(
|
||||||
slice_size)
|
inputs, lora_b_weights, lora_indices_tensor, seq_len_tensor, output_tensor, slice_offset, slice_size
|
||||||
|
)
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
from typing import Callable, Optional, Tuple, Union
|
from collections.abc import Callable
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from vllm.lora.punica_wrapper.punica_base import PunicaWrapperBase
|
from vllm.lora.punica_wrapper.punica_base import PunicaWrapperBase
|
||||||
@@ -18,26 +18,30 @@ class PunicaWrapperNPU(PunicaWrapperBase):
|
|||||||
Multi-LoRA, and to provide the interface for the pytorch punica ops.
|
Multi-LoRA, and to provide the interface for the pytorch punica ops.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, max_num_batched_tokens: int, max_batches: int,
|
def __init__(self, max_num_batched_tokens: int, max_batches: int, device: torch.device | str, **kwargs):
|
||||||
device: Union[torch.device, str], **kwargs):
|
PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, device)
|
||||||
PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches,
|
|
||||||
device)
|
|
||||||
refresh_all_lora_classes()
|
refresh_all_lora_classes()
|
||||||
self.lora_config = kwargs.get("lora_config")
|
self.lora_config = kwargs.get("lora_config")
|
||||||
if get_ascend_device_type() == AscendDeviceType._310P or (
|
if get_ascend_device_type() == AscendDeviceType._310P or (
|
||||||
self.lora_config is not None
|
self.lora_config is not None and self.lora_config.max_lora_rank >= 128
|
||||||
and self.lora_config.max_lora_rank >= 128):
|
):
|
||||||
from vllm.lora.ops.torch_ops import (bgmv_expand,
|
from vllm.lora.ops.torch_ops import (
|
||||||
bgmv_expand_slice,
|
bgmv_expand,
|
||||||
bgmv_shrink, sgmv_expand,
|
bgmv_expand_slice,
|
||||||
sgmv_expand_slice,
|
bgmv_shrink,
|
||||||
sgmv_shrink)
|
sgmv_expand,
|
||||||
|
sgmv_expand_slice,
|
||||||
|
sgmv_shrink,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
from vllm_ascend.lora.lora_ops import (bgmv_expand,
|
from vllm_ascend.lora.lora_ops import (
|
||||||
bgmv_expand_slice,
|
bgmv_expand,
|
||||||
bgmv_shrink, sgmv_expand,
|
bgmv_expand_slice,
|
||||||
sgmv_expand_slice,
|
bgmv_shrink,
|
||||||
sgmv_shrink)
|
sgmv_expand,
|
||||||
|
sgmv_expand_slice,
|
||||||
|
sgmv_shrink,
|
||||||
|
)
|
||||||
self.bgmv_expand = bgmv_expand
|
self.bgmv_expand = bgmv_expand
|
||||||
self.bgmv_expand_slice = bgmv_expand_slice
|
self.bgmv_expand_slice = bgmv_expand_slice
|
||||||
self.bgmv_shrink = bgmv_shrink
|
self.bgmv_shrink = bgmv_shrink
|
||||||
@@ -52,7 +56,7 @@ class PunicaWrapperNPU(PunicaWrapperBase):
|
|||||||
w_t_all: torch.Tensor,
|
w_t_all: torch.Tensor,
|
||||||
scale: float,
|
scale: float,
|
||||||
):
|
):
|
||||||
#No LoRA request, so return directly
|
# No LoRA request, so return directly
|
||||||
if self.no_lora:
|
if self.no_lora:
|
||||||
return
|
return
|
||||||
self.sgmv_shrink(
|
self.sgmv_shrink(
|
||||||
@@ -79,7 +83,7 @@ class PunicaWrapperNPU(PunicaWrapperBase):
|
|||||||
w_t_all: torch.Tensor,
|
w_t_all: torch.Tensor,
|
||||||
add_inputs: bool,
|
add_inputs: bool,
|
||||||
):
|
):
|
||||||
#No LoRA request, so return directly
|
# No LoRA request, so return directly
|
||||||
if self.no_lora:
|
if self.no_lora:
|
||||||
return
|
return
|
||||||
self.sgmv_expand(
|
self.sgmv_expand(
|
||||||
@@ -108,7 +112,7 @@ class PunicaWrapperNPU(PunicaWrapperBase):
|
|||||||
y_slice_size: int,
|
y_slice_size: int,
|
||||||
add_inputs: bool,
|
add_inputs: bool,
|
||||||
):
|
):
|
||||||
#No LoRA request, so return directly
|
# No LoRA request, so return directly
|
||||||
if self.no_lora:
|
if self.no_lora:
|
||||||
return
|
return
|
||||||
self.sgmv_expand_slice(
|
self.sgmv_expand_slice(
|
||||||
@@ -130,8 +134,7 @@ class PunicaWrapperNPU(PunicaWrapperBase):
|
|||||||
y_slice_size: int,
|
y_slice_size: int,
|
||||||
add_inputs: bool,
|
add_inputs: bool,
|
||||||
):
|
):
|
||||||
self.bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices,
|
self.bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset, y_slice_size, add_inputs)
|
||||||
y_offset, y_slice_size, add_inputs)
|
|
||||||
|
|
||||||
def _apply_expand(
|
def _apply_expand(
|
||||||
self,
|
self,
|
||||||
@@ -148,13 +151,10 @@ class PunicaWrapperNPU(PunicaWrapperBase):
|
|||||||
GEMM of lora'b.
|
GEMM of lora'b.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
expand_slice_fun: Callable = (self._expand_slice_prefill
|
expand_slice_fun: Callable = self._expand_slice_prefill if self.is_prefill else self._expand_slice_decode
|
||||||
if self.is_prefill else
|
|
||||||
self._expand_slice_decode)
|
|
||||||
expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_inputs)
|
expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_inputs)
|
||||||
|
|
||||||
def _apply_shrink(self, y: torch.Tensor, x: torch.Tensor,
|
def _apply_shrink(self, y: torch.Tensor, x: torch.Tensor, w_t_all: torch.Tensor, scale: float):
|
||||||
w_t_all: torch.Tensor, scale: float):
|
|
||||||
"""
|
"""
|
||||||
Perform the ` y+=x@w_t_all` computation, which is suitable for the
|
Perform the ` y+=x@w_t_all` computation, which is suitable for the
|
||||||
GEMM of lora'a.
|
GEMM of lora'a.
|
||||||
@@ -165,14 +165,18 @@ class PunicaWrapperNPU(PunicaWrapperBase):
|
|||||||
"""
|
"""
|
||||||
y_org = y
|
y_org = y
|
||||||
y = y.view(-1, y.shape[-1])
|
y = y.view(-1, y.shape[-1])
|
||||||
shrink_fun: Callable = (self._shrink_prefill
|
shrink_fun: Callable = self._shrink_prefill if self.is_prefill else self._shrink_decode
|
||||||
if self.is_prefill else self._shrink_decode)
|
|
||||||
shrink_fun(y, x, w_t_all, scale)
|
shrink_fun(y, x, w_t_all, scale)
|
||||||
y = y.view_as(y_org)
|
y = y.view_as(y_org)
|
||||||
|
|
||||||
def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor],
|
def add_shrink(
|
||||||
x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...],
|
self,
|
||||||
scale: float, **kwargs):
|
y: tuple[torch.Tensor, ...] | torch.Tensor,
|
||||||
|
x: torch.Tensor,
|
||||||
|
lora_a_stacked: tuple[torch.Tensor, ...],
|
||||||
|
scale: float,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Performs GEMM for multiple slices of lora_a.
|
Performs GEMM for multiple slices of lora_a.
|
||||||
When `is_prefill is` true, it indicates that it is currently the
|
When `is_prefill is` true, it indicates that it is currently the
|
||||||
@@ -194,18 +198,19 @@ class PunicaWrapperNPU(PunicaWrapperBase):
|
|||||||
x = x.view(-1, x.shape[-1])
|
x = x.view(-1, x.shape[-1])
|
||||||
# TODO fuse these kernels
|
# TODO fuse these kernels
|
||||||
for slice_idx in range(len(lora_a_stacked)):
|
for slice_idx in range(len(lora_a_stacked)):
|
||||||
self._apply_shrink(y[slice_idx], x, lora_a_stacked[slice_idx],
|
self._apply_shrink(y[slice_idx], x, lora_a_stacked[slice_idx], scale)
|
||||||
scale)
|
|
||||||
|
|
||||||
def add_expand(self,
|
def add_expand(
|
||||||
y: torch.Tensor,
|
self,
|
||||||
x: Union[Tuple[torch.Tensor, ...], torch.Tensor],
|
y: torch.Tensor,
|
||||||
lora_b_stacked: Tuple[torch.Tensor, ...],
|
x: tuple[torch.Tensor, ...] | torch.Tensor,
|
||||||
lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
|
lora_b_stacked: tuple[torch.Tensor, ...],
|
||||||
output_slices: Tuple[int, ...],
|
lora_bias_stacked: tuple[torch.Tensor, ...] | None,
|
||||||
offset_start: int = 0,
|
output_slices: tuple[int, ...],
|
||||||
add_inputs=True,
|
offset_start: int = 0,
|
||||||
**kwargs) -> None:
|
add_inputs=True,
|
||||||
|
**kwargs,
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Performs GEMM and bias addition for multiple slices of lora_b.
|
Performs GEMM and bias addition for multiple slices of lora_b.
|
||||||
|
|
||||||
@@ -229,8 +234,7 @@ class PunicaWrapperNPU(PunicaWrapperBase):
|
|||||||
y = y.view(-1, y.shape[-1])
|
y = y.view(-1, y.shape[-1])
|
||||||
offset_left = offset_start
|
offset_left = offset_start
|
||||||
if lora_bias_stacked is not None:
|
if lora_bias_stacked is not None:
|
||||||
self._apply_bias(self.token_lora_indices, y, output_slices,
|
self._apply_bias(self.token_lora_indices, y, output_slices, lora_bias_stacked)
|
||||||
lora_bias_stacked)
|
|
||||||
for slice_idx in range(len(lora_b_stacked)):
|
for slice_idx in range(len(lora_b_stacked)):
|
||||||
self._apply_expand(
|
self._apply_expand(
|
||||||
y,
|
y,
|
||||||
@@ -243,12 +247,9 @@ class PunicaWrapperNPU(PunicaWrapperBase):
|
|||||||
offset_left += output_slices[slice_idx]
|
offset_left += output_slices[slice_idx]
|
||||||
y = y.view_as(y_org)
|
y = y.view_as(y_org)
|
||||||
|
|
||||||
def add_lora_embedding(self,
|
def add_lora_embedding(
|
||||||
y: torch.Tensor,
|
self, y: torch.Tensor, x: torch.Tensor, lora_b_stacked: torch.Tensor, add_inputs: bool = True, **kwargs
|
||||||
x: torch.Tensor,
|
) -> None:
|
||||||
lora_b_stacked: torch.Tensor,
|
|
||||||
add_inputs: bool = True,
|
|
||||||
**kwargs) -> None:
|
|
||||||
"""
|
"""
|
||||||
Applies lora specifically for VocabParallelEmbeddingWithLoRA.
|
Applies lora specifically for VocabParallelEmbeddingWithLoRA.
|
||||||
|
|
||||||
@@ -263,21 +264,22 @@ class PunicaWrapperNPU(PunicaWrapperBase):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# Embedding layer only need expand op
|
# Embedding layer only need expand op
|
||||||
expand_fun: Callable = (self._expand_prefill
|
expand_fun: Callable = self._expand_prefill if self.is_prefill else self._expand_decode
|
||||||
if self.is_prefill else self._expand_decode)
|
|
||||||
x = x.to(torch.float32)
|
x = x.to(torch.float32)
|
||||||
expand_fun(y, x, lora_b_stacked, add_inputs)
|
expand_fun(y, x, lora_b_stacked, add_inputs)
|
||||||
|
|
||||||
def add_lora_linear(self,
|
def add_lora_linear(
|
||||||
y: torch.Tensor,
|
self,
|
||||||
x: torch.Tensor,
|
y: torch.Tensor,
|
||||||
lora_a_stacked: Tuple[torch.Tensor, ...],
|
x: torch.Tensor,
|
||||||
lora_b_stacked: Tuple[torch.Tensor, ...],
|
lora_a_stacked: tuple[torch.Tensor, ...],
|
||||||
scale: float,
|
lora_b_stacked: tuple[torch.Tensor, ...],
|
||||||
output_slices: Tuple[int, ...],
|
scale: float,
|
||||||
*,
|
output_slices: tuple[int, ...],
|
||||||
buffer: Optional[Tuple[torch.Tensor, ...]] = None,
|
*,
|
||||||
**kwargs) -> None:
|
buffer: tuple[torch.Tensor, ...] | None = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Applicable to linear-related lora.
|
Applicable to linear-related lora.
|
||||||
|
|
||||||
@@ -308,27 +310,22 @@ class PunicaWrapperNPU(PunicaWrapperBase):
|
|||||||
# We set the buffer to be float32 by default, consistent with the
|
# We set the buffer to be float32 by default, consistent with the
|
||||||
# triton op
|
# triton op
|
||||||
buffer = tuple(
|
buffer = tuple(
|
||||||
torch.zeros(
|
torch.zeros((x.size(0), r), dtype=torch.float32, device=x.device) for _ in range(len(output_slices))
|
||||||
(x.size(0), r), dtype=torch.float32, device=x.device)
|
)
|
||||||
for _ in range(len(output_slices)))
|
|
||||||
self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs)
|
self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs)
|
||||||
self.add_expand(y,
|
self.add_expand(y, buffer, lora_b_stacked, None, output_slices, add_inputs=True, **kwargs)
|
||||||
buffer,
|
|
||||||
lora_b_stacked,
|
|
||||||
None,
|
|
||||||
output_slices,
|
|
||||||
add_inputs=True,
|
|
||||||
**kwargs)
|
|
||||||
|
|
||||||
def add_lora_logits(self,
|
def add_lora_logits(
|
||||||
y: torch.Tensor,
|
self,
|
||||||
x: torch.Tensor,
|
y: torch.Tensor,
|
||||||
lora_a_stacked: torch.Tensor,
|
x: torch.Tensor,
|
||||||
lora_b_stacked: torch.Tensor,
|
lora_a_stacked: torch.Tensor,
|
||||||
scale,
|
lora_b_stacked: torch.Tensor,
|
||||||
*,
|
scale,
|
||||||
buffer: Optional[torch.Tensor] = None,
|
*,
|
||||||
**kwargs) -> None:
|
buffer: torch.Tensor | None = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Applies lora specifically for LogitsProcessorWithLoRA.
|
Applies lora specifically for LogitsProcessorWithLoRA.
|
||||||
|
|
||||||
@@ -350,9 +347,7 @@ class PunicaWrapperNPU(PunicaWrapperBase):
|
|||||||
r = lora_b_stacked.size(-1)
|
r = lora_b_stacked.size(-1)
|
||||||
|
|
||||||
if buffer is None:
|
if buffer is None:
|
||||||
buffer = torch.zeros((x.size(0), r),
|
buffer = torch.zeros((x.size(0), r), dtype=torch.float32, device=x.device)
|
||||||
dtype=torch.float32,
|
|
||||||
device=x.device)
|
|
||||||
|
|
||||||
indices = self.sampler_indices
|
indices = self.sampler_indices
|
||||||
|
|
||||||
|
|||||||
@@ -1,91 +1,75 @@
|
|||||||
from typing import Optional
|
|
||||||
|
|
||||||
import vllm
|
import vllm
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
from vllm.config import LoRAConfig
|
from vllm.config import LoRAConfig
|
||||||
from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
|
from vllm.lora.layers import (
|
||||||
MergedColumnParallelLinearWithLoRA,
|
ColumnParallelLinearWithLoRA,
|
||||||
MergedQKVParallelLinearWithLoRA,
|
MergedColumnParallelLinearWithLoRA,
|
||||||
QKVParallelLinearWithLoRA,
|
MergedQKVParallelLinearWithLoRA,
|
||||||
RowParallelLinearWithLoRA,
|
QKVParallelLinearWithLoRA,
|
||||||
VocabParallelEmbeddingWithLoRA)
|
RowParallelLinearWithLoRA,
|
||||||
|
VocabParallelEmbeddingWithLoRA,
|
||||||
|
)
|
||||||
from vllm.lora.layers.utils import _not_fully_sharded_can_replace
|
from vllm.lora.layers.utils import _not_fully_sharded_can_replace
|
||||||
|
|
||||||
from vllm_ascend.ops.linear import (AscendColumnParallelLinear,
|
from vllm_ascend.ops.linear import (
|
||||||
AscendMergedColumnParallelLinear,
|
AscendColumnParallelLinear,
|
||||||
AscendQKVParallelLinear,
|
AscendMergedColumnParallelLinear,
|
||||||
AscendRowParallelLinear)
|
AscendQKVParallelLinear,
|
||||||
from vllm_ascend.ops.vocab_parallel_embedding import \
|
AscendRowParallelLinear,
|
||||||
AscendVocabParallelEmbedding
|
)
|
||||||
|
from vllm_ascend.ops.vocab_parallel_embedding import AscendVocabParallelEmbedding
|
||||||
|
|
||||||
|
|
||||||
class AscendColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
|
class AscendColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def can_replace_layer(
|
def can_replace_layer(
|
||||||
cls,
|
cls,
|
||||||
source_layer: nn.Module,
|
source_layer: nn.Module,
|
||||||
lora_config: LoRAConfig,
|
lora_config: LoRAConfig,
|
||||||
packed_modules_list: list,
|
packed_modules_list: list,
|
||||||
model_config: Optional[PretrainedConfig],
|
model_config: PretrainedConfig | None,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
return type(source_layer) is AscendColumnParallelLinear
|
return type(source_layer) is AscendColumnParallelLinear
|
||||||
|
|
||||||
|
|
||||||
class AscendMergedColumnParallelLinearWithLoRA(
|
class AscendMergedColumnParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA):
|
||||||
MergedColumnParallelLinearWithLoRA):
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def can_replace_layer(
|
def can_replace_layer(
|
||||||
cls,
|
cls,
|
||||||
source_layer: nn.Module,
|
source_layer: nn.Module,
|
||||||
lora_config: LoRAConfig,
|
lora_config: LoRAConfig,
|
||||||
packed_modules_list: list,
|
packed_modules_list: list,
|
||||||
model_config: Optional[PretrainedConfig],
|
model_config: PretrainedConfig | None,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
return type(source_layer) is AscendMergedColumnParallelLinear
|
return type(source_layer) is AscendMergedColumnParallelLinear
|
||||||
|
|
||||||
|
|
||||||
class AscendRowParallelLinearWithLoRA(RowParallelLinearWithLoRA):
|
class AscendRowParallelLinearWithLoRA(RowParallelLinearWithLoRA):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def can_replace_layer(
|
def can_replace_layer(
|
||||||
cls,
|
cls,
|
||||||
source_layer: nn.Module,
|
source_layer: nn.Module,
|
||||||
lora_config: LoRAConfig,
|
lora_config: LoRAConfig,
|
||||||
packed_modules_list: list,
|
packed_modules_list: list,
|
||||||
model_config: Optional[PretrainedConfig],
|
model_config: PretrainedConfig | None,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
return type(source_layer) is AscendRowParallelLinear
|
return type(source_layer) is AscendRowParallelLinear
|
||||||
|
|
||||||
|
|
||||||
class AscendVocabParallelEmbeddingWithLoRA(VocabParallelEmbeddingWithLoRA):
|
class AscendVocabParallelEmbeddingWithLoRA(VocabParallelEmbeddingWithLoRA):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def can_replace_layer(
|
def can_replace_layer(
|
||||||
cls,
|
cls,
|
||||||
source_layer: nn.Module,
|
source_layer: nn.Module,
|
||||||
lora_config: LoRAConfig,
|
lora_config: LoRAConfig,
|
||||||
packed_modules_list: list,
|
packed_modules_list: list,
|
||||||
model_config: Optional[PretrainedConfig],
|
model_config: PretrainedConfig | None,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
return type(source_layer) is AscendVocabParallelEmbedding
|
return type(source_layer) is AscendVocabParallelEmbedding
|
||||||
|
|
||||||
|
|
||||||
class AscendQKVParallelLinearWithLoRA(QKVParallelLinearWithLoRA):
|
class AscendQKVParallelLinearWithLoRA(QKVParallelLinearWithLoRA):
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@_not_fully_sharded_can_replace
|
|
||||||
def can_replace_layer(cls, source_layer: nn.Module,
|
|
||||||
lora_config: LoRAConfig, packed_modules_list: list,
|
|
||||||
model_config: Optional[PretrainedConfig]) -> bool:
|
|
||||||
return type(source_layer) is AscendQKVParallelLinear and len(
|
|
||||||
packed_modules_list) == 1
|
|
||||||
|
|
||||||
|
|
||||||
class AscendMergedQKVParallelLinearWithLoRA(MergedQKVParallelLinearWithLoRA):
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@_not_fully_sharded_can_replace
|
@_not_fully_sharded_can_replace
|
||||||
def can_replace_layer(
|
def can_replace_layer(
|
||||||
@@ -93,18 +77,28 @@ class AscendMergedQKVParallelLinearWithLoRA(MergedQKVParallelLinearWithLoRA):
|
|||||||
source_layer: nn.Module,
|
source_layer: nn.Module,
|
||||||
lora_config: LoRAConfig,
|
lora_config: LoRAConfig,
|
||||||
packed_modules_list: list,
|
packed_modules_list: list,
|
||||||
model_config: Optional[PretrainedConfig],
|
model_config: PretrainedConfig | None,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
return (type(source_layer) is AscendQKVParallelLinear
|
return type(source_layer) is AscendQKVParallelLinear and len(packed_modules_list) == 1
|
||||||
and len(packed_modules_list) == 3)
|
|
||||||
|
|
||||||
|
class AscendMergedQKVParallelLinearWithLoRA(MergedQKVParallelLinearWithLoRA):
|
||||||
|
@classmethod
|
||||||
|
@_not_fully_sharded_can_replace
|
||||||
|
def can_replace_layer(
|
||||||
|
cls,
|
||||||
|
source_layer: nn.Module,
|
||||||
|
lora_config: LoRAConfig,
|
||||||
|
packed_modules_list: list,
|
||||||
|
model_config: PretrainedConfig | None,
|
||||||
|
) -> bool:
|
||||||
|
return type(source_layer) is AscendQKVParallelLinear and len(packed_modules_list) == 3
|
||||||
|
|
||||||
|
|
||||||
def refresh_all_lora_classes():
|
def refresh_all_lora_classes():
|
||||||
vllm.lora.utils._all_lora_classes.add(AscendColumnParallelLinearWithLoRA)
|
vllm.lora.utils._all_lora_classes.add(AscendColumnParallelLinearWithLoRA)
|
||||||
vllm.lora.utils._all_lora_classes.add(
|
vllm.lora.utils._all_lora_classes.add(AscendMergedColumnParallelLinearWithLoRA)
|
||||||
AscendMergedColumnParallelLinearWithLoRA)
|
|
||||||
vllm.lora.utils._all_lora_classes.add(AscendRowParallelLinearWithLoRA)
|
vllm.lora.utils._all_lora_classes.add(AscendRowParallelLinearWithLoRA)
|
||||||
vllm.lora.utils._all_lora_classes.add(AscendVocabParallelEmbeddingWithLoRA)
|
vllm.lora.utils._all_lora_classes.add(AscendVocabParallelEmbeddingWithLoRA)
|
||||||
vllm.lora.utils._all_lora_classes.add(AscendQKVParallelLinearWithLoRA)
|
vllm.lora.utils._all_lora_classes.add(AscendQKVParallelLinearWithLoRA)
|
||||||
vllm.lora.utils._all_lora_classes.add(
|
vllm.lora.utils._all_lora_classes.add(AscendMergedQKVParallelLinearWithLoRA)
|
||||||
AscendMergedQKVParallelLinearWithLoRA)
|
|
||||||
|
|||||||
Reference in New Issue
Block a user