[Lint]Style: Convert vllm-ascend/ to ruff format(Batch #5) (#5996)

### What this PR does / why we need it?
**Scope of Changes**:
| File Path |
| :--- |
|
`.../distributed/kv_transfer/kv_pool/ascend_store/ascend_store_connector.py`
|
|
`vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/backend/backend.py`
|
| `
.../distributed/kv_transfer/kv_pool/ascend_store/backend/memcache_backend.py`
|
| `
.../distributed/kv_transfer/kv_pool/ascend_store/backend/mooncake_backend.py`
|
| `
vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/config_data.py`
|
| `
vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/kv_transfer.py`
|
| `
vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/pool_scheduler.py`
|
| `
vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/pool_worker.py`
|
| `
.../distributed/kv_transfer/kv_pool/cpu_offload/cpu_kv_cache_manager.py`
|
| `
.../distributed/kv_transfer/kv_pool/cpu_offload/cpu_offload_connector.py`
|
| ` vllm_ascend/distributed/kv_transfer/kv_pool/cpu_offload/metadata.py`
|
| ` vllm_ascend/distributed/kv_transfer/kv_pool/ucm_connector.py` |
| `
vllm_ascend/distributed/kv_transfer/utils/mooncake_transfer_engine.py` |
| ` vllm_ascend/distributed/kv_transfer/utils/utils.py` |
| ` vllm_ascend/kv_offload/cpu_npu.py` |
| ` vllm_ascend/kv_offload/npu.py` |
| ` vllm_ascend/lora/lora_ops.py` |
| ` vllm_ascend/lora/punica_npu.py` |
| ` vllm_ascend/lora/utils.py` |

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.13.0
- vLLM main:
2c24bc6996

---------

Signed-off-by: MrZ20 <2609716663@qq.com>
Signed-off-by: SILONG ZENG <2609716663@qq.com>
This commit is contained in:
SILONG ZENG
2026-01-24 22:45:38 +08:00
committed by GitHub
parent 7faa6878a6
commit 6ccccad102
21 changed files with 866 additions and 1034 deletions

View File

@@ -1,6 +1,8 @@
[mypy]
; warn_return_any = True
warn_unused_configs = True
; disable errors about unchecked annotations for now.
disable_error_code = annotation-unchecked
; Suppress all missing import errors from torch_npu for mypy.
[mypy-torch_npu.*]
@@ -31,4 +33,4 @@ ignore_missing_imports = True
ignore_missing_imports = True
[mypy-ucm.*]
ignore_missing_imports = True
ignore_missing_imports = True

View File

@@ -51,11 +51,6 @@ line-length = 120
# Folder to be modified
exclude = [
"tests/**",
# (5)
"vllm_ascend/distributed/kv_transfer/kv_pool/**",
"vllm_ascend/distributed/kv_transfer/utils/**",
"vllm_ascend/kv_offload/**",
"vllm_ascend/lora/**",
# (7)
"vllm_ascend/quantization/**",
"vllm_ascend/sample/*.py",

View File

@@ -1,11 +1,10 @@
import threading
from typing import Any, Optional
from typing import Any
import torch
import zmq
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole
from vllm.forward_context import ForwardContext
from vllm.logger import logger
from vllm.utils.network_utils import make_zmq_socket
@@ -17,31 +16,27 @@ from vllm.v1.request import Request
from vllm.v1.serial_utils import MsgpackDecoder
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.pool_scheduler import (
KVPoolScheduler, get_zmq_rpc_path_lookup)
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.pool_worker import \
KVPoolWorker
KVPoolScheduler,
get_zmq_rpc_path_lookup,
)
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.pool_worker import KVPoolWorker
class AscendStoreConnector(KVConnectorBase_V1):
def __init__(self,
vllm_config: VllmConfig,
role: KVConnectorRole,
kv_cache_config: Optional[KVCacheConfig] = None):
super().__init__(vllm_config=vllm_config,
role=role,
kv_cache_config=kv_cache_config)
def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole, kv_cache_config: KVCacheConfig | None = None):
super().__init__(vllm_config=vllm_config, role=role, kv_cache_config=kv_cache_config)
self.kv_role = vllm_config.kv_transfer_config.kv_role
self.use_layerwise = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"use_layerwise", False)
self.use_layerwise = vllm_config.kv_transfer_config.kv_connector_extra_config.get("use_layerwise", False)
self.consumer_is_to_put = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"consumer_is_to_put", False)
"consumer_is_to_put", False
)
connector_name = vllm_config.kv_transfer_config.kv_connector
if connector_name == "MooncakeConnectorStoreV1":
logger.warning(
"It is recommended to use the AscendStoreConnector, as the MoonCakeStoreConnector will be removed in the future."
"It is recommended to use the AscendStoreConnector, "
"as the MoonCakeStoreConnector will be removed in the future."
)
self.kv_caches: dict[str, torch.Tensor] = {}
@@ -49,8 +44,7 @@ class AscendStoreConnector(KVConnectorBase_V1):
self.sended_but_unfinished_reqs: set[str] = set()
if role == KVConnectorRole.SCHEDULER:
self.connector_scheduler = KVPoolScheduler(vllm_config,
self.use_layerwise)
self.connector_scheduler = KVPoolScheduler(vllm_config, self.use_layerwise)
else:
self.connector_worker = KVPoolWorker(
vllm_config,
@@ -59,27 +53,19 @@ class AscendStoreConnector(KVConnectorBase_V1):
assert self.connector_worker is not None
if vllm_config.parallel_config.rank == 0:
self.lookup_server = LookupKeyServer(self.connector_worker,
vllm_config,
self.use_layerwise)
self.lookup_server = LookupKeyServer(self.connector_worker, vllm_config, self.use_layerwise)
############################################################
# Scheduler Side Methods
############################################################
def get_num_new_matched_tokens(
self, request: "Request",
num_computed_tokens: int) -> tuple[int, bool]:
def get_num_new_matched_tokens(self, request: "Request", num_computed_tokens: int) -> tuple[int, bool]:
assert self.connector_scheduler is not None
return self.connector_scheduler.get_num_new_matched_tokens(
request, num_computed_tokens)
return self.connector_scheduler.get_num_new_matched_tokens(request, num_computed_tokens)
def update_state_after_alloc(self, request: "Request",
blocks: "KVCacheBlocks",
num_external_tokens: int):
def update_state_after_alloc(self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int):
assert self.connector_scheduler is not None
return self.connector_scheduler.update_state_after_alloc(
request, blocks, num_external_tokens)
return self.connector_scheduler.update_state_after_alloc(request, blocks, num_external_tokens)
def build_connector_meta(
self,
@@ -92,7 +78,7 @@ class AscendStoreConnector(KVConnectorBase_V1):
self,
request: "Request",
block_ids: list[int],
) -> tuple[bool, Optional[dict[str, Any]]]:
) -> tuple[bool, dict[str, Any] | None]:
assert self.connector_scheduler is not None
return self.connector_scheduler.request_finished(request, block_ids)
@@ -103,8 +89,7 @@ class AscendStoreConnector(KVConnectorBase_V1):
assert self.connector_worker is not None
self.connector_worker.register_kv_caches(kv_caches)
def start_load_kv(self, forward_context: "ForwardContext",
**kwargs) -> None:
def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
assert self.connector_worker is not None
self.connector_worker.start_load_kv(self._get_connector_metadata())
@@ -113,8 +98,9 @@ class AscendStoreConnector(KVConnectorBase_V1):
return
self.connector_worker.wait_for_layer_load()
def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
attn_metadata: "AttentionMetadata", **kwargs) -> None:
def save_kv_layer(
self, layer_name: str, kv_layer: torch.Tensor, attn_metadata: "AttentionMetadata", **kwargs
) -> None:
if not self.use_layerwise:
return
@@ -133,17 +119,16 @@ class AscendStoreConnector(KVConnectorBase_V1):
self.connector_worker.wait_for_save(self._get_connector_metadata())
def get_finished(self,
finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
"""Get the finished recving and sending requests."""
assert self.connector_worker is not None
done_sending, done_recving = self.connector_worker.get_finished(
finished_req_ids, self._get_connector_metadata())
finished_req_ids, self._get_connector_metadata()
)
return done_sending, done_recving
class LookupKeyServer:
def __init__(
self,
pool_worker: KVPoolWorker,
@@ -171,8 +156,7 @@ class LookupKeyServer:
token_len = int.from_bytes(all_frames[0], byteorder="big")
hash_frames = all_frames[1:]
hashes_str = self.decoder.decode(hash_frames)
result = self.pool_worker.lookup_scheduler(
token_len, hashes_str, self.use_layerwise)
result = self.pool_worker.lookup_scheduler(token_len, hashes_str, self.use_layerwise)
response = result.to_bytes(4, "big")
self.socket.send(response)

View File

@@ -4,13 +4,15 @@ from vllm.config import ParallelConfig
class Backend(ABC):
@abstractmethod
def __init__(self, parallel_config: ParallelConfig):
pass
@abstractmethod
def set_device(self):
pass
@abstractmethod
def register_buffer(self, ptrs: list[int], lengths: list[int]):
pass
@@ -19,11 +21,9 @@ class Backend(ABC):
pass
@abstractmethod
def put(self, keys: list[str], addrs: list[list[int]],
sizes: list[list[int]]):
def put(self, keys: list[str], addrs: list[list[int]], sizes: list[list[int]]):
pass
@abstractmethod
def get(self, keys: list[str], addrs: list[list[int]],
sizes: list[list[int]]):
def get(self, keys: list[str], addrs: list[list[int]], sizes: list[list[int]]):
pass

View File

@@ -5,8 +5,7 @@ import torch
from vllm.config import ParallelConfig
from vllm.logger import logger
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.backend import \
Backend
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.backend import Backend
from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type
@@ -18,7 +17,6 @@ class MmcDirect(Enum):
class MemcacheBackend(Backend):
def __init__(self, parallel_config: ParallelConfig):
try:
from memcache_hybrid import DistributedObjectStore # type: ignore
@@ -26,21 +24,17 @@ class MemcacheBackend(Backend):
raise ImportError(
"Please install memcache by following the instructions at "
"https://gitee.com/ascend/memfabric_hybrid " # noqa: E501
"to run vLLM with MemcacheConnector.") from e
"to run vLLM with MemcacheConnector."
) from e
try:
soc_version = get_ascend_device_type()
if soc_version in {AscendDeviceType.A2}:
import torch
from vllm.distributed import get_world_group
tmp_tensor = torch.zeros(1, device="npu")
output_tensor_list = [
torch.empty_like(tmp_tensor)
for _ in range(torch.distributed.get_world_size())
]
torch.distributed.all_gather(
output_tensor_list,
tmp_tensor,
group=get_world_group().device_group)
output_tensor_list = [torch.empty_like(tmp_tensor) for _ in range(torch.distributed.get_world_size())]
torch.distributed.all_gather(output_tensor_list, tmp_tensor, group=get_world_group().device_group)
self.rank = parallel_config.rank
self.store = DistributedObjectStore()
res = self.store.init(self.rank)
@@ -54,8 +48,7 @@ class MemcacheBackend(Backend):
logger.error("Configuration loading failed: %s", e)
raise
except Exception as exc:
logger.error(
"An error occurred while loading the configuration: %s", exc)
logger.error("An error occurred while loading the configuration: %s", exc)
raise
def set_device(self):
@@ -73,22 +66,18 @@ class MemcacheBackend(Backend):
def exists(self, keys: list[str]) -> list[int]:
return self.store.batch_is_exist(keys)
def get(self, key: list[str], addr: list[list[int]],
size: list[list[int]]):
def get(self, key: list[str], addr: list[list[int]], size: list[list[int]]):
try:
res = self.store.batch_get_into_layers(key, addr, size,
MmcDirect.COPY_G2L.value)
res = self.store.batch_get_into_layers(key, addr, size, MmcDirect.COPY_G2L.value)
for value in res:
if value != 0:
logger.error(f"Failed to get key {key},res:{res}")
except Exception as e:
logger.error(f"Failed to get key {key}. {e}")
def put(self, key: list[str], addr: list[list[int]],
size: list[list[int]]):
def put(self, key: list[str], addr: list[list[int]], size: list[list[int]]):
try:
res = self.store.batch_put_from_layers(key, addr, size,
MmcDirect.COPY_L2G.value)
res = self.store.batch_put_from_layers(key, addr, size, MmcDirect.COPY_L2G.value)
for value in res:
if value != 0:
logger.error(f"Failed to get key {key},res:{res}")

View File

@@ -2,10 +2,9 @@
import json
import os
import re
import torch
from dataclasses import dataclass
from typing import Union
import torch
# Third Party
from mooncake.store import ReplicateConfig # type: ignore
@@ -13,17 +12,14 @@ from vllm.config import ParallelConfig
from vllm.logger import logger
from vllm.utils.network_utils import get_ip
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.backend import \
Backend
from vllm_ascend.distributed.kv_transfer.utils.mooncake_transfer_engine import \
global_te
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.backend import Backend
from vllm_ascend.distributed.kv_transfer.utils.mooncake_transfer_engine import global_te
DEFAULT_GLOBAL_SEGMENT_SIZE = 3355443200 # 3.125 GiB
DEFAULT_LOCAL_BUFFER_SIZE = 1073741824 # 1.0 GiB
class MooncakeBackend(Backend):
def __init__(self, parallel_config: ParallelConfig):
try:
from mooncake.store import MooncakeDistributedStore # type: ignore
@@ -31,23 +27,25 @@ class MooncakeBackend(Backend):
raise ImportError(
"Please install mooncake by following the instructions at "
"https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md " # noqa: E501
"to run vLLM with MooncakeConnector.") from e
"to run vLLM with MooncakeConnector."
) from e
self.config = MooncakeStoreConfig.load_from_env()
self.store = MooncakeDistributedStore()
self.rank = parallel_config.rank
if self.config.protocol == "ascend":
local_hostname = get_ip()
transfer_engine = global_te.get_transfer_engine(local_hostname,
device_name=None)
self.local_seg = local_hostname + ":" + str(
transfer_engine.get_rpc_port())
ret = self.store.setup(self.local_seg, self.config.metadata_server,
self.config.global_segment_size,
self.config.local_buffer_size,
self.config.protocol,
self.config.device_name,
self.config.master_server_address,
transfer_engine.get_engine())
transfer_engine = global_te.get_transfer_engine(local_hostname, device_name=None)
self.local_seg = local_hostname + ":" + str(transfer_engine.get_rpc_port())
ret = self.store.setup(
self.local_seg,
self.config.metadata_server,
self.config.global_segment_size,
self.config.local_buffer_size,
self.config.protocol,
self.config.device_name,
self.config.master_server_address,
transfer_engine.get_engine(),
)
if ret != 0:
msg = "Initialize mooncake failed."
logger.error(msg)
@@ -63,25 +61,21 @@ class MooncakeBackend(Backend):
def exists(self, keys: list[str]) -> list[int]:
return self.store.batch_is_exist(keys)
def put(self, keys: list[str], addrs: list[list[int]],
sizes: list[list[int]]):
def put(self, keys: list[str], addrs: list[list[int]], sizes: list[list[int]]):
try:
config = ReplicateConfig()
config.preferred_segment = self.local_seg
config.prefer_alloc_in_same_node = True
res = self.store.batch_put_from_multi_buffers(
keys, addrs, sizes, config)
res = self.store.batch_put_from_multi_buffers(keys, addrs, sizes, config)
for value in res:
if value < 0:
logger.error(f"Failed to put key {keys},res:{res}")
except Exception as e:
logger.error(f"Failed to put key {keys},error:{e}")
def get(self, keys: list[str], addrs: list[list[int]],
sizes: list[list[int]]):
def get(self, keys: list[str], addrs: list[list[int]], sizes: list[list[int]]):
try:
res = self.store.batch_get_into_multi_buffers(
keys, addrs, sizes, True)
res = self.store.batch_get_into_multi_buffers(keys, addrs, sizes, True)
for value in res:
if value < 0:
logger.error(f"Failed to get key {keys}, res:{res}")
@@ -92,7 +86,7 @@ class MooncakeBackend(Backend):
@dataclass
class MooncakeStoreConfig:
metadata_server: str
global_segment_size: Union[int, str]
global_segment_size: int | str
local_buffer_size: int
protocol: str
device_name: str
@@ -105,33 +99,32 @@ class MooncakeStoreConfig:
return MooncakeStoreConfig(
metadata_server=config.get("metadata_server"),
global_segment_size=_parse_global_segment_size(
config.get("global_segment_size",
DEFAULT_GLOBAL_SEGMENT_SIZE)),
local_buffer_size=_parse_global_segment_size(
config.get("local_buffer_size", DEFAULT_LOCAL_BUFFER_SIZE)),
config.get("global_segment_size", DEFAULT_GLOBAL_SEGMENT_SIZE)
),
local_buffer_size=_parse_global_segment_size(config.get("local_buffer_size", DEFAULT_LOCAL_BUFFER_SIZE)),
protocol=config.get("protocol", "ascend"),
device_name=config.get("device_name", ""),
master_server_address=config.get("master_server_address"))
master_server_address=config.get("master_server_address"),
)
@staticmethod
def load_from_env() -> "MooncakeStoreConfig":
config_path = os.getenv("MOONCAKE_CONFIG_PATH")
if not config_path:
raise ValueError(
"The environment variable 'MOONCAKE_CONFIG_PATH' is not set.")
raise ValueError("The environment variable 'MOONCAKE_CONFIG_PATH' is not set.")
return MooncakeStoreConfig.from_file(config_path)
def _parse_global_segment_size(value) -> int:
"""
Parse storage size strings with support for units: GB, MB, KB, B
Args:
value: Input value (int, str, or other convertible types)
Returns:
int: Size in bytes
Raises:
ValueError: For invalid format, missing number, or negative values
TypeError: For unsupported input types
@@ -143,54 +136,50 @@ def _parse_global_segment_size(value) -> int:
try:
return int(value)
except (TypeError, ValueError) as e:
raise TypeError(
f"Unsupported type for global_segment_size: {type(value)}"
) from e
raise TypeError(f"Unsupported type for global_segment_size: {type(value)}") from e
cleaned_input = value.strip().lower()
if not cleaned_input:
raise ValueError("global segment size cannot be empty.")
UNIT_MULTIPLIERS = {
'gb': 1024**3, # 1 GB = 1024^3 bytes
'mb': 1024**2, # 1 MB = 1024^2 bytes
'kb': 1024, # 1 KB = 1024 bytes
'b': 1 # 1 B = 1 byte
"gb": 1024**3, # 1 GB = 1024^3 bytes
"mb": 1024**2, # 1 MB = 1024^2 bytes
"kb": 1024, # 1 KB = 1024 bytes
"b": 1, # 1 B = 1 byte
}
pattern = r'^\s*([\d.]+)\s*(gb|mb|kb|b)?\s*$'
pattern = r"^\s*([\d.]+)\s*(gb|mb|kb|b)?\s*$"
match = re.match(pattern, cleaned_input)
if not match:
raise ValueError(f"Invalid format: '{value}'")
number_str = match.group(1)
unit = match.group(2) or 'b'
unit = match.group(2) or "b"
multiplier = UNIT_MULTIPLIERS[unit]
return _convert_to_bytes(number_str, multiplier, value)
def _convert_to_bytes(number_str: str, multiplier: int,
original_input: str) -> int:
def _convert_to_bytes(number_str: str, multiplier: int, original_input: str) -> int:
"""
Convert numeric string to byte count
Args:
number_str: Numeric portion of input
multiplier: Unit conversion factor
original_input: Original input string (for error messages)
Returns:
int: Byte count
Raises:
ValueError: For invalid numbers or negative results
"""
try:
numeric_value = float(number_str)
except ValueError:
raise ValueError(
f"Invalid numeric value '{number_str}' in: '{original_input}'")
raise ValueError(f"Invalid numeric value '{number_str}' in: '{original_input}'")
# Calculate byte count
try:
byte_count = int(numeric_value * multiplier)

View File

@@ -1,16 +1,16 @@
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Iterable, List, Optional, Tuple, Union
from typing import Optional
import torch
from vllm.distributed.kv_transfer.kv_connector.v1.base import \
KVConnectorMetadata
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata
from vllm.logger import logger
from vllm.utils.math_utils import cdiv
from vllm.v1.core.kv_cache_utils import BlockHash
from vllm.v1.core.sched.output import NewRequestData
#Parameters related to the key
# Parameters related to the key
@dataclass
class KeyMetadata:
"""name of the LLM model"""
@@ -32,23 +32,26 @@ class PoolKey:
chunk_hash: str
def __hash__(self):
return hash((
self.key_metadata.model_name,
self.key_metadata.head_or_tp_rank,
self.key_metadata.pcp_rank,
self.key_metadata.dcp_rank,
self.key_metadata.pp_rank,
self.chunk_hash,
))
return hash(
(
self.key_metadata.model_name,
self.key_metadata.head_or_tp_rank,
self.key_metadata.pcp_rank,
self.key_metadata.dcp_rank,
self.key_metadata.pp_rank,
self.chunk_hash,
)
)
def to_string(self):
return (
f"{self.key_metadata.model_name}"
f"@pcp{self.key_metadata.pcp_rank}@dcp{self.key_metadata.dcp_rank}"
f"@head_or_tp_rank:{self.key_metadata.head_or_tp_rank}"
f"@pp_rank:{self.key_metadata.pp_rank}@{self.chunk_hash}")
f"@pp_rank:{self.key_metadata.pp_rank}@{self.chunk_hash}"
)
def split_layers(self, num_layers: int) -> List["LayerPoolKey"]:
def split_layers(self, num_layers: int) -> list["LayerPoolKey"]:
"""Split the key into multiple keys for each layer"""
keys = []
for layer_id in range(num_layers):
@@ -57,7 +60,8 @@ class PoolKey:
self.key_metadata,
self.chunk_hash,
layer_id,
))
)
)
return keys
@@ -68,14 +72,16 @@ class LayerPoolKey(PoolKey):
layer_id: int
def __hash__(self):
return hash((
self.key_metadata.model_name,
self.key_metadata.head_or_tp_rank,
self.key_metadata.pcp_rank,
self.key_metadata.dcp_rank,
self.chunk_hash,
self.layer_id,
))
return hash(
(
self.key_metadata.model_name,
self.key_metadata.head_or_tp_rank,
self.key_metadata.pcp_rank,
self.key_metadata.dcp_rank,
self.chunk_hash,
self.layer_id,
)
)
def to_string(self):
return (
@@ -85,10 +91,8 @@ class LayerPoolKey(PoolKey):
)
class ChunkedTokenDatabase():
def __init__(self, metadata: KeyMetadata, block_size: int, use_mla: bool,
partitions: Optional[List[int]]):
class ChunkedTokenDatabase:
def __init__(self, metadata: KeyMetadata, block_size: int, use_mla: bool, partitions: list[int] | None):
self.metadata = metadata
self.block_size = block_size
self.use_mla = use_mla
@@ -96,9 +100,7 @@ class ChunkedTokenDatabase():
self.block_len: list[int] = []
self.partitions = partitions
def _make_key_by_hash(self,
chunk_hash: str,
layer_id: Optional[int] = None):
def _make_key_by_hash(self, chunk_hash: str, layer_id: int | None = None):
assert self.metadata is not None
return PoolKey(
self.metadata,
@@ -116,8 +118,7 @@ class ChunkedTokenDatabase():
size_list = []
block_id = block_ids[start // self.block_size]
for index, base_addr in enumerate(self.kv_caches_base_addr):
block_len = (self.block_len[index % 2]
if self.use_mla else self.block_len[0])
block_len = self.block_len[index % 2] if self.use_mla else self.block_len[0]
addr = base_addr + block_id * block_len
length = int(block_len / self.block_size * (end - start))
@@ -125,22 +126,17 @@ class ChunkedTokenDatabase():
size_list.append(length)
return addr_list, size_list, block_id
def prepare_value_layer(self, start: int, end: int, block_ids: list[int],
layer_id: int):
def prepare_value_layer(self, start: int, end: int, block_ids: list[int], layer_id: int):
block_id = block_ids[start // self.block_size]
if self.use_mla:
addr_k = self.kv_caches_base_addr[layer_id *
2] + block_id * self.block_len[0]
addr_v = self.kv_caches_base_addr[layer_id * 2 +
1] + block_id * self.block_len[1]
addr_k = self.kv_caches_base_addr[layer_id * 2] + block_id * self.block_len[0]
addr_v = self.kv_caches_base_addr[layer_id * 2 + 1] + block_id * self.block_len[1]
length_k = int(self.block_len[0] / self.block_size * (end - start))
length_v = int(self.block_len[1] / self.block_size * (end - start))
size_list = [length_k, length_v]
else:
addr_k = self.kv_caches_base_addr[layer_id *
2] + block_id * self.block_len[0]
addr_v = self.kv_caches_base_addr[layer_id * 2 +
1] + block_id * self.block_len[0]
addr_k = self.kv_caches_base_addr[layer_id * 2] + block_id * self.block_len[0]
addr_v = self.kv_caches_base_addr[layer_id * 2 + 1] + block_id * self.block_len[0]
length = int(self.block_len[0] / self.block_size * (end - start))
size_list = [length, length]
addr_list = [addr_k, addr_v]
@@ -149,9 +145,9 @@ class ChunkedTokenDatabase():
def process_tokens(
self,
token_len: int,
block_hashes: Union[list[BlockHash], list[str]],
block_hashes: list[BlockHash] | list[str],
mask_num: int = 0,
) -> Iterable[Tuple[int, int, PoolKey]]:
) -> Iterable[tuple[int, int, PoolKey]]:
"""Process the tokens and return the corresponding cache engine keys.
:param Union[torch.Tensor, List[int]] tokens: The tokens to process.
@@ -202,10 +198,10 @@ class ChunkedTokenDatabase():
start = 0
for j, part in enumerate(self.partitions):
# part * 2 because addr and size contain both k and v
end = len(addr_list) if j == len(
self.partitions) - 1 else start + part * 2
end = len(addr_list) if j == len(self.partitions) - 1 else start + part * 2
new_str = key[i].replace( # type: ignore[attr-defined]
"@pp_rank:0", f"@pp_rank:{j}", 1)
"@pp_rank:0", f"@pp_rank:{j}", 1
)
new_key.append(new_str)
new_addr.append(addr_list[start:end])
new_size.append(size_list[start:end])
@@ -213,7 +209,7 @@ class ChunkedTokenDatabase():
return new_key, new_addr, new_size
#Parameters related to the connector metadata
# Parameters related to the connector metadata
@dataclass
class LoadSpec:
# Number of tokens cached in vLLM
@@ -273,7 +269,7 @@ class RequestTracker:
def update(
self,
new_block_ids: Union[tuple[list[int], ...], list[int]],
new_block_ids: tuple[list[int], ...] | list[int],
) -> None:
"""Update the request tracker when a running request is
scheduled again
@@ -286,8 +282,7 @@ class RequestTracker:
elif isinstance(new_block_ids, list):
pass
else:
raise ValueError(
f"Unsupported new_block_ids type {type(new_block_ids)}")
raise ValueError(f"Unsupported new_block_ids type {type(new_block_ids)}")
self.allocated_block_ids.extend(new_block_ids)
@@ -302,22 +297,22 @@ class ReqMeta:
block_hashes: list[BlockHash]
can_save: Optional[bool] = None
can_save: bool | None = None
# load_spec
load_spec: Optional[LoadSpec] = None
load_spec: LoadSpec | None = None
is_last_chunk: Optional[bool] = None
is_last_chunk: bool | None = None
current_event: Optional[torch.npu.Event] = None
current_event: torch.npu.Event | None = None
@staticmethod
def from_request_tracker(
tracker: RequestTracker,
block_size: int,
load_spec: Optional[LoadSpec] = None,
skip_save: Optional[bool] = False,
block_hashes: list[BlockHash] = [],
is_last_chunk: Optional[bool] = None,
load_spec: LoadSpec | None = None,
skip_save: bool | None = False,
block_hashes: list[BlockHash] | None = None,
is_last_chunk: bool | None = None,
discard_partial_chunks: bool = True,
) -> Optional["ReqMeta"]:
"""Create the request metadata from a request tracker.
@@ -333,17 +328,17 @@ class ReqMeta:
the request metadata if we need to perform load/save
operations, None otherwise.
"""
if block_hashes is None:
block_hashes = []
input_token_len = tracker.token_len
# For save operation: do not save if the following condition is met
# 1. has already been saved before (num_saved_tokens > 0)
# 2. number of unsaved tokens is not reached the chunk boundary
chunk_boundary = (cdiv(tracker.num_saved_tokens + 1, block_size) *
block_size if discard_partial_chunks else 0)
chunk_boundary = cdiv(tracker.num_saved_tokens + 1, block_size) * block_size if discard_partial_chunks else 0
# Calculate number of tokens to save based on discard_partial_chunks
# setting
num_tokens_to_save = ((input_token_len // block_size * block_size)
if discard_partial_chunks else input_token_len)
num_tokens_to_save = (input_token_len // block_size * block_size) if discard_partial_chunks else input_token_len
skip_save = skip_save or num_tokens_to_save < chunk_boundary
if skip_save and load_spec is None:
@@ -363,9 +358,7 @@ class ReqMeta:
else:
# Do not load if not in `can_load` state
load_spec = None
logger.debug(
f"request:{tracker.req_id}, meta save spec:{not skip_save}, meta load spec:{load_spec}"
)
logger.debug(f"request:{tracker.req_id}, meta save spec:{not skip_save}, meta load spec:{load_spec}")
return ReqMeta(
req_id=tracker.req_id,
token_len_chunk=num_tokens_to_save,
@@ -378,7 +371,6 @@ class ReqMeta:
class AscendConnectorMetadata(KVConnectorMetadata):
def __init__(self, unfinished_request_ids, preempted_req_ids):
self.requests = []
self.unfinished_request_ids = unfinished_request_ids
@@ -396,10 +388,10 @@ class AscendConnectorMetadata(KVConnectorMetadata):
@dataclass
class LasyerMultiBlockReqMeta:
req_id: str
keys: List[LayerPoolKey]
starts: List[int]
keys: list[LayerPoolKey]
starts: list[int]
ends: list[int]
block_ids: list[int]
layer_id: int
is_last_chunk: Optional[bool] = True
current_event: Optional[torch.npu.Event] = None
is_last_chunk: bool | None = True
current_event: torch.npu.Event | None = None

View File

@@ -7,8 +7,7 @@ from typing import Any
import torch
from vllm.logger import logger
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.backend import \
Backend
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.backend import Backend
# isort: off
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.config_data import (
@@ -20,10 +19,16 @@ from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.config_data import
class KVTransferThread(threading.Thread):
def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
block_size: int, tp_rank: int, dcp_size: int,
ready_event: threading.Event, name: str):
def __init__(
self,
m_store: Backend,
token_database: ChunkedTokenDatabase,
block_size: int,
tp_rank: int,
dcp_size: int,
ready_event: threading.Event,
name: str,
):
super().__init__(daemon=True, name=name)
self.m_store = m_store
self.ready_event = ready_event
@@ -39,7 +44,7 @@ class KVTransferThread(threading.Thread):
def add_request(
self,
request: ReqMeta,
request: ReqMeta | LasyerMultiBlockReqMeta,
) -> torch.Tensor:
self.request_queue.put(request)
@@ -98,17 +103,20 @@ class KVTransferThread(threading.Thread):
class KVCacheStoreSendingThread(KVTransferThread):
def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
block_size: int, tp_rank: int, dcp_size: int, put_step: int,
kv_role: str, ready_event: threading.Event):
super().__init__(m_store,
token_database,
block_size,
tp_rank,
dcp_size,
ready_event,
name="KVCacheSendingThread")
def __init__(
self,
m_store: Backend,
token_database: ChunkedTokenDatabase,
block_size: int,
tp_rank: int,
dcp_size: int,
put_step: int,
kv_role: str,
ready_event: threading.Event,
):
super().__init__(
m_store, token_database, block_size, tp_rank, dcp_size, ready_event, name="KVCacheSendingThread"
)
self.put_step = put_step
self.kv_role = kv_role
self.stored_requests = defaultdict[str, int](int)
@@ -139,16 +147,15 @@ class KVCacheStoreSendingThread(KVTransferThread):
self.request_queue.task_done()
return
for start, end, key in self.token_database.process_tokens(
token_len, req_meta.block_hashes):
for start, end, key in self.token_database.process_tokens(token_len, req_meta.block_hashes):
starts.append(start)
ends.append(end)
keys.append(key.to_string())
if not self.dcp_size > 1:
starts = starts[self.tp_rank % self.put_step::self.put_step]
ends = ends[self.tp_rank % self.put_step::self.put_step]
keys = keys[self.tp_rank % self.put_step::self.put_step]
starts = starts[self.tp_rank % self.put_step :: self.put_step]
ends = ends[self.tp_rank % self.put_step :: self.put_step]
keys = keys[self.tp_rank % self.put_step :: self.put_step]
if not keys:
self.dec_stored_request(req_id)
@@ -165,8 +172,7 @@ class KVCacheStoreSendingThread(KVTransferThread):
keys = keys[skip_block_num:]
logger.info(
"Storing KV cache for %d out of %d blocks "
"(skip_block_num=%d) for request %s",
"Storing KV cache for %d out of %d blocks (skip_block_num=%d) for request %s",
len(keys),
token_len // self.block_size,
skip_block_num,
@@ -183,14 +189,12 @@ class KVCacheStoreSendingThread(KVTransferThread):
addrs = []
sizes = []
for index, start in enumerate(starts):
addr, size, _ = self.token_database.prepare_value(
start, ends[index], block_ids)
addr, size, _ = self.token_database.prepare_value(start, ends[index], block_ids)
addrs.append(addr)
sizes.append(size)
if self.kv_role == "kv_consumer":
keys, addrs, sizes = self.token_database.decode_adaptor_prefill_pp(
keys, addrs, sizes)
keys, addrs, sizes = self.token_database.decode_adaptor_prefill_pp(keys, addrs, sizes)
if current_event is not None:
current_event.synchronize()
@@ -201,69 +205,69 @@ class KVCacheStoreSendingThread(KVTransferThread):
class KVCacheStoreRecvingThread(KVTransferThread):
def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
block_size: int, tp_rank: int, dcp_size: int,
ready_event: threading.Event):
super().__init__(m_store,
token_database,
block_size,
tp_rank,
dcp_size,
ready_event,
name="KVCacheStoreRecvingThread")
def __init__(
self,
m_store: Backend,
token_database: ChunkedTokenDatabase,
block_size: int,
tp_rank: int,
dcp_size: int,
ready_event: threading.Event,
):
super().__init__(
m_store, token_database, block_size, tp_rank, dcp_size, ready_event, name="KVCacheStoreRecvingThread"
)
def _handle_request(self, req_meta: ReqMeta):
token_len = req_meta.load_spec.token_len # type: ignore[union-attr]
req_id = req_meta.req_id
mask_num = (
req_meta.load_spec.vllm_cached_tokens # type: ignore[union-attr]
// self.block_size * self.block_size)
// self.block_size
* self.block_size
)
addr_list = []
size_list = []
key_list = []
for start, end, key in self.token_database.process_tokens(
token_len, req_meta.block_hashes, mask_num):
addr, size, _ = self.token_database.prepare_value(
start, end, req_meta.block_ids)
for start, end, key in self.token_database.process_tokens(token_len, req_meta.block_hashes, mask_num):
addr, size, _ = self.token_database.prepare_value(start, end, req_meta.block_ids)
key_list.append(key.to_string())
addr_list.append(addr)
size_list.append(size)
key_list_c = key_list[self.tp_rank %
len(key_list):] + key_list[:self.tp_rank %
len(key_list)]
addr_list_c = addr_list[self.tp_rank %
len(addr_list):] + addr_list[:self.tp_rank %
len(addr_list)]
size_list_c = size_list[self.tp_rank %
len(size_list):] + size_list[:self.tp_rank %
len(size_list)]
key_list_c = key_list[self.tp_rank % len(key_list) :] + key_list[: self.tp_rank % len(key_list)]
addr_list_c = addr_list[self.tp_rank % len(addr_list) :] + addr_list[: self.tp_rank % len(addr_list)]
size_list_c = size_list[self.tp_rank % len(size_list) :] + size_list[: self.tp_rank % len(size_list)]
self.m_store.get(key_list_c, addr_list_c, size_list_c)
self.set_finished_request(req_id)
self.request_queue.task_done()
class KVCacheStoreLayerSendingThread(KVTransferThread):
def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
block_size: int, tp_rank: int, dcp_size: int, put_step: int,
ready_event: threading.Event, num_layers: int):
super().__init__(m_store,
token_database,
block_size,
tp_rank,
dcp_size,
ready_event,
name="KVCacheStoreLayerSendingThread")
def __init__(
self,
m_store: Backend,
token_database: ChunkedTokenDatabase,
block_size: int,
tp_rank: int,
dcp_size: int,
put_step: int,
ready_event: threading.Event,
num_layers: int,
):
super().__init__(
m_store, token_database, block_size, tp_rank, dcp_size, ready_event, name="KVCacheStoreLayerSendingThread"
)
self.final_layer_id = num_layers - 1
self.put_step = put_step
def add_request( # type: ignore[override]
self, req_meta: ReqMeta) -> torch.Tensor:
self, req_meta: ReqMeta
) -> torch.Tensor:
self.request_queue.put(req_meta)
def _handle_request( # type: ignore[override]
self, req_meta: LasyerMultiBlockReqMeta):
self, req_meta: LasyerMultiBlockReqMeta
):
starts = req_meta.starts
ends = req_meta.ends
keys = req_meta.keys
@@ -272,9 +276,9 @@ class KVCacheStoreLayerSendingThread(KVTransferThread):
total_block = len(keys)
is_last_chunk = req_meta.is_last_chunk
if not self.dcp_size > 1:
starts = starts[self.tp_rank % self.put_step::self.put_step]
ends = ends[self.tp_rank % self.put_step::self.put_step]
keys = keys[self.tp_rank % self.put_step::self.put_step]
starts = starts[self.tp_rank % self.put_step :: self.put_step]
ends = ends[self.tp_rank % self.put_step :: self.put_step]
keys = keys[self.tp_rank % self.put_step :: self.put_step]
if not keys:
if is_last_chunk:
@@ -300,7 +304,8 @@ class KVCacheStoreLayerSendingThread(KVTransferThread):
size_list = []
for index, key in enumerate(key_list):
addr, size = self.token_database.prepare_value_layer(
starts[index], ends[index], req_meta.block_ids, layer_id)
starts[index], ends[index], req_meta.block_ids, layer_id
)
addr_list.append(addr)
size_list.append(size)
@@ -313,8 +318,7 @@ class KVCacheStoreLayerSendingThread(KVTransferThread):
self.request_queue.task_done()
logger.info(
"Storing KV cache for %d out of %d blocks "
"(skip_block_num=%d) for request %s",
"Storing KV cache for %d out of %d blocks (skip_block_num=%d) for request %s",
len(keys),
total_block,
skip_block_num,
@@ -323,44 +327,42 @@ class KVCacheStoreLayerSendingThread(KVTransferThread):
class KVCacheStoreLayerRecvingThread(KVTransferThread):
def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
block_size: int, tp_rank: int, dcp_size: int,
ready_event: threading.Event, get_event: threading.Event):
super().__init__(m_store,
token_database,
block_size,
tp_rank,
dcp_size,
ready_event,
name="KVCacheStoreLayerRecvingThread")
def __init__(
self,
m_store: Backend,
token_database: ChunkedTokenDatabase,
block_size: int,
tp_rank: int,
dcp_size: int,
ready_event: threading.Event,
get_event: threading.Event,
):
super().__init__(
m_store, token_database, block_size, tp_rank, dcp_size, ready_event, name="KVCacheStoreLayerRecvingThread"
)
self.get_event = get_event
def add_request( # type: ignore[override]
self, req_meta: LasyerMultiBlockReqMeta) -> torch.Tensor:
self, req_meta: LasyerMultiBlockReqMeta
) -> torch.Tensor:
self.request_queue.put(req_meta)
def _handle_request( # type: ignore[override]
self, req_meta: LasyerMultiBlockReqMeta):
self, req_meta: LasyerMultiBlockReqMeta
):
addr_list = []
size_list = []
key_list = []
for index, key in enumerate(req_meta.keys):
addr, size = self.token_database.prepare_value_layer(
req_meta.starts[index], req_meta.ends[index],
req_meta.block_ids, req_meta.layer_id)
req_meta.starts[index], req_meta.ends[index], req_meta.block_ids, req_meta.layer_id
)
key_list.append(key.to_string())
addr_list.append(addr)
size_list.append(size)
key_list_c = key_list[self.tp_rank %
len(key_list):] + key_list[:self.tp_rank %
len(key_list)]
addr_list_c = addr_list[self.tp_rank %
len(addr_list):] + addr_list[:self.tp_rank %
len(addr_list)]
size_list_c = size_list[self.tp_rank %
len(size_list):] + size_list[:self.tp_rank %
len(size_list)]
key_list_c = key_list[self.tp_rank % len(key_list) :] + key_list[: self.tp_rank % len(key_list)]
addr_list_c = addr_list[self.tp_rank % len(addr_list) :] + addr_list[: self.tp_rank % len(addr_list)]
size_list_c = size_list[self.tp_rank % len(size_list) :] + size_list[: self.tp_rank % len(size_list)]
self.m_store.get(key_list_c, addr_list_c, size_list_c)
self.request_queue.task_done()

View File

@@ -1,10 +1,9 @@
from typing import Any, Optional
from typing import Any
import vllm.envs as envs
import zmq
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import \
KVConnectorMetadata
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata
from vllm.logger import logger
from vllm.utils.network_utils import make_zmq_socket
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
@@ -14,27 +13,29 @@ from vllm.v1.request import Request
from vllm.v1.serial_utils import MsgpackEncoder
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.config_data import (
AscendConnectorMetadata, LoadSpec, ReqMeta, RequestTracker)
AscendConnectorMetadata,
LoadSpec,
ReqMeta,
RequestTracker,
)
class KVPoolScheduler:
def __init__(self, vllm_config: "VllmConfig", use_layerwise):
self.use_layerwise = use_layerwise
self.kv_role = vllm_config.kv_transfer_config.kv_role
self.consumer_is_to_load = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"consumer_is_to_load", False)
"consumer_is_to_load", False
)
self.consumer_is_to_put = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"consumer_is_to_put", False)
self.load_async = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"load_async", False)
"consumer_is_to_put", False
)
self.load_async = vllm_config.kv_transfer_config.kv_connector_extra_config.get("load_async", False)
self.client = LookupKeyClient(vllm_config)
# request_id -> (vllm cached tokes, kvpool cached tokens)
self.load_specs: dict[str, LoadSpec] = {}
self.pcp_size = getattr(vllm_config.parallel_config,
"prefill_context_parallel_size", 1)
self.dcp_size = getattr(vllm_config.parallel_config,
"decode_context_parallel_size", 1)
self.pcp_size = getattr(vllm_config.parallel_config, "prefill_context_parallel_size", 1)
self.dcp_size = getattr(vllm_config.parallel_config, "decode_context_parallel_size", 1)
self._block_size = vllm_config.cache_config.block_size
if self.pcp_size > 1:
@@ -45,9 +46,9 @@ class KVPoolScheduler:
self._request_trackers: dict[str, RequestTracker] = {}
self._preempted_req_ids: set[str] = set()
# Whether to discard partial chunks
self._discard_partial_chunks = (
vllm_config.kv_transfer_config.get_from_extra_config(
"discard_partial_chunks", True))
self._discard_partial_chunks = vllm_config.kv_transfer_config.get_from_extra_config(
"discard_partial_chunks", True
)
self._unfinished_requests: dict[str, tuple[Request, list[int]]] = {}
self._unfinished_request_ids: set[str] = set()
@@ -72,13 +73,11 @@ class KVPoolScheduler:
return 0, False
if self._discard_partial_chunks:
token_len = len(request.prompt_token_ids
) // self._block_size * self._block_size
token_len = len(request.prompt_token_ids) // self._block_size * self._block_size
else:
token_len = len(request.prompt_token_ids)
num_external_hit_tokens = self.client.lookup(token_len,
request.block_hashes)
num_external_hit_tokens = self.client.lookup(token_len, request.block_hashes)
if num_external_hit_tokens == request.num_tokens:
num_external_hit_tokens -= 1
@@ -107,9 +106,7 @@ class KVPoolScheduler:
return need_to_allocate, self.load_async and not self.use_layerwise
def update_state_after_alloc(self, request: "Request",
blocks: "KVCacheBlocks",
num_external_tokens: int):
def update_state_after_alloc(self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int):
"""
Update KVConnector state after temporary buffer alloc.
@@ -120,8 +117,7 @@ class KVPoolScheduler:
if num_external_tokens > 0:
local_block_ids = blocks.get_block_ids()[0]
self._unfinished_requests[request.request_id] = (request,
local_block_ids)
self._unfinished_requests[request.request_id] = (request, local_block_ids)
self._unfinished_request_ids.add(request.request_id)
if request.request_id not in self.load_specs:
# No KV tokens from external KV cache, return
@@ -133,18 +129,20 @@ class KVPoolScheduler:
return
assert (
num_external_tokens > 0 and num_external_tokens
== self.load_specs[request.request_id].kvpool_cached_tokens -
self.load_specs[request.request_id].vllm_cached_tokens
), (f"Mismatch in number of tokens: {num_external_tokens} vs "
num_external_tokens > 0
and num_external_tokens
== self.load_specs[request.request_id].kvpool_cached_tokens
- self.load_specs[request.request_id].vllm_cached_tokens
), (
f"Mismatch in number of tokens: {num_external_tokens} vs "
f"{self.load_specs[request.request_id].kvpool_cached_tokens} - "
f"{self.load_specs[request.request_id].vllm_cached_tokens}"
f" for request {request.request_id}")
f" for request {request.request_id}"
)
self.load_specs[request.request_id].can_load = True
def build_connector_meta(
self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
def build_connector_meta(self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
"""Attach the connector metadata to the request object.
This function should NOT modify other fields in the scheduler_output
@@ -155,14 +153,13 @@ class KVPoolScheduler:
scheduler_output (SchedulerOutput): the scheduler output object.
"""
force_skip_save = (self.kv_role == "kv_consumer"
and not self.consumer_is_to_put)
force_skip_save = self.kv_role == "kv_consumer" and not self.consumer_is_to_put
for finished_req_id in scheduler_output.finished_req_ids:
self._request_trackers.pop(finished_req_id, None)
self._unfinished_requests.pop(finished_req_id, None)
self._unfinished_request_ids.discard(finished_req_id)
for req_id in scheduler_output.preempted_req_ids:
self._preempted_req_ids.update(scheduler_output.preempted_req_ids)
self._request_trackers.pop(req_id, None)
@@ -173,9 +170,7 @@ class KVPoolScheduler:
for request in scheduler_output.scheduled_new_reqs:
# Right now, we only load KV for new requests
load_spec = self.load_specs.pop(request.req_id, None)
num_tokens_to_compute = (
request.num_computed_tokens +
scheduler_output.num_scheduled_tokens[request.req_id])
num_tokens_to_compute = request.num_computed_tokens + scheduler_output.num_scheduled_tokens[request.req_id]
request_tuple = self._unfinished_requests.get(request.req_id)
request_real = request_tuple[0] # type: ignore[index]
if not isinstance(request.block_ids[0], list):
@@ -183,25 +178,25 @@ class KVPoolScheduler:
else:
unfolded_block_ids = request.block_ids[0].copy()
request_tracker = RequestTracker(
req_id=request.req_id,
token_len=num_tokens_to_compute,
allocated_block_ids=unfolded_block_ids,
num_saved_tokens=0,
)
req_id=request.req_id,
token_len=num_tokens_to_compute,
allocated_block_ids=unfolded_block_ids,
num_saved_tokens=0,
)
self._request_trackers[request.req_id] = request_tracker
last_chunk_tokens_num = ((len(request.prompt_token_ids) //
self._block_size * self._block_size)
if self._discard_partial_chunks else len(
request.prompt_token_ids))
last_chunk_tokens_num = (
(len(request.prompt_token_ids) // self._block_size * self._block_size)
if self._discard_partial_chunks
else len(request.prompt_token_ids)
)
req_meta = ReqMeta.from_request_tracker(
request_tracker,
self._block_size,
load_spec=load_spec,
skip_save=force_skip_save,
block_hashes=request_real.block_hashes,
is_last_chunk=request_tracker.token_len
>= last_chunk_tokens_num,
is_last_chunk=request_tracker.token_len >= last_chunk_tokens_num,
discard_partial_chunks=self._discard_partial_chunks,
)
if req_meta is not None:
@@ -224,8 +219,8 @@ class KVPoolScheduler:
request_tuple = self._unfinished_requests.get(req_id)
request_real = request_tuple[0] # type: ignore[index]
num_tokens_to_compute = (
request_real.num_computed_tokens +
scheduler_output.num_scheduled_tokens[req_id])
request_real.num_computed_tokens + scheduler_output.num_scheduled_tokens[req_id]
)
request_tracker = RequestTracker(
req_id=req_id,
token_len=num_tokens_to_compute,
@@ -233,21 +228,21 @@ class KVPoolScheduler:
num_saved_tokens=0,
)
self._request_trackers[req_id] = request_tracker
last_chunk_tokens_num = ((len(request_real.prompt_token_ids) //
self._block_size * self._block_size)
if self._discard_partial_chunks else len(
request_real.prompt_token_ids))
last_chunk_tokens_num = (
(len(request_real.prompt_token_ids) // self._block_size * self._block_size)
if self._discard_partial_chunks
else len(request_real.prompt_token_ids)
)
req_meta = ReqMeta.from_request_tracker(
request_tracker,
self._block_size,
load_spec=load_spec,
skip_save=force_skip_save,
block_hashes=request_real.block_hashes,
is_last_chunk=request_tracker.token_len
>= last_chunk_tokens_num,
is_last_chunk=request_tracker.token_len >= last_chunk_tokens_num,
discard_partial_chunks=self._discard_partial_chunks,
)
# decode/chunked request
else:
request_tracker = self._request_trackers[req_id]
@@ -256,48 +251,44 @@ class KVPoolScheduler:
if req_tuple:
request = req_tuple[0]
num_current_tokens = request_tracker.token_len
new_token_ids = request.all_token_ids[
num_current_tokens:num_current_tokens + num_new_tokens]
new_token_ids = request.all_token_ids[num_current_tokens : num_current_tokens + num_new_tokens]
request_tracker.token_len += len(new_token_ids)
else:
raise ValueError(
f"Request {req_id} is not in _unfinished_requests, "
f"but it is scheduled to be cached")
f"Request {req_id} is not in _unfinished_requests, but it is scheduled to be cached"
)
num_computed_token = cached_reqs.num_computed_tokens[i]
if num_computed_token >= len(request.prompt_token_ids):
continue
request_tracker.update(new_block_ids)
last_chunk_tokens_num = ((len(request.prompt_token_ids) //
self._block_size * self._block_size)
if self._discard_partial_chunks else
len(request.prompt_token_ids))
last_chunk_tokens_num = (
(len(request.prompt_token_ids) // self._block_size * self._block_size)
if self._discard_partial_chunks
else len(request.prompt_token_ids)
)
req_meta = ReqMeta.from_request_tracker(
request_tracker,
self._block_size,
load_spec=None,
skip_save=force_skip_save,
block_hashes=request.block_hashes,
is_last_chunk=request_tracker.token_len
>= last_chunk_tokens_num,
is_last_chunk=request_tracker.token_len >= last_chunk_tokens_num,
discard_partial_chunks=self._discard_partial_chunks,
)
if req_meta is not None:
meta.add_request(req_meta)
request_ids = [
req.req_id for req in scheduler_output.scheduled_new_reqs
]
for request_id, (request,
block_ids) in self._unfinished_requests.items():
request_ids = [req.req_id for req in scheduler_output.scheduled_new_reqs]
for request_id, (request, block_ids) in self._unfinished_requests.items():
if request_id not in request_ids and request_id not in cached_reqs.req_ids:
load_spec = self.load_specs.pop(request_id, None)
if not load_spec:
continue
num_tokens_to_compute = load_spec.kvpool_cached_tokens
if (num_tokens_to_compute % self._block_size
!= 0) and (num_tokens_to_compute
== len(request.prompt_token_ids) - 1):
if (num_tokens_to_compute % self._block_size != 0) and (
num_tokens_to_compute == len(request.prompt_token_ids) - 1
):
num_tokens_to_compute = num_tokens_to_compute + 1
request_tracker = RequestTracker(
req_id=request_id,
@@ -324,7 +315,7 @@ class KVPoolScheduler:
self,
request: "Request",
block_ids: list[int],
) -> tuple[bool, Optional[dict[str, Any]]]:
) -> tuple[bool, dict[str, Any] | None]:
"""
Once a request is finished, determine whether request blocks
should be freed now or will be sent asynchronously and freed later.
@@ -336,13 +327,11 @@ class KVPoolScheduler:
return False, None
delay_free_blocks = len(block_ids) > 0
if delay_free_blocks:
logger.info("Delaying free of %d blocks for request %s",
len(block_ids), request.request_id)
logger.info("Delaying free of %d blocks for request %s", len(block_ids), request.request_id)
return delay_free_blocks, None
class LookupKeyClient:
def __init__(self, vllm_config: "VllmConfig"):
self.encoder = MsgpackEncoder()
self.ctx = zmq.Context() # type: ignore[attr-defined]

View File

@@ -1,37 +1,45 @@
import math
import threading
from typing import Dict, Generator, Optional, Type
from collections.abc import Callable, Generator
import torch
from vllm.config import VllmConfig
from vllm.distributed import (get_decode_context_model_parallel_rank,
get_decode_context_model_parallel_world_size,
get_pcp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.distributed import (
get_decode_context_model_parallel_rank,
get_decode_context_model_parallel_world_size,
get_pcp_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from vllm.logger import logger
from vllm.v1.core.kv_cache_utils import BlockHash
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.backend import \
Backend
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.memcache_backend import \
MemcacheBackend
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.mooncake_backend import \
MooncakeBackend
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.backend import Backend
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.memcache_backend import MemcacheBackend
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.mooncake_backend import MooncakeBackend
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.config_data import (
AscendConnectorMetadata, ChunkedTokenDatabase, KeyMetadata,
LasyerMultiBlockReqMeta, ReqMeta)
AscendConnectorMetadata,
ChunkedTokenDatabase,
KeyMetadata,
LasyerMultiBlockReqMeta,
ReqMeta,
)
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.kv_transfer import (
KVCacheStoreLayerRecvingThread, KVCacheStoreLayerSendingThread,
KVCacheStoreRecvingThread, KVCacheStoreSendingThread, KVTransferThread)
KVCacheStoreLayerRecvingThread,
KVCacheStoreLayerSendingThread,
KVCacheStoreRecvingThread,
KVCacheStoreSendingThread,
KVTransferThread,
)
backend_map: Dict[str, Type[Backend]] = {
backend_map: dict[str, Callable[..., Backend]] = {
"mooncake": MooncakeBackend,
"memcache": MemcacheBackend,
}
class KVPoolWorker:
#The main class for the cache engine.
# The main class for the cache engine.
def __init__(
self,
@@ -42,9 +50,7 @@ class KVPoolWorker:
parallel_config = vllm_config.parallel_config
self.dp_rank = parallel_config.data_parallel_rank
self.use_mla = False
if (hasattr(model_config, "use_mla")
and isinstance(model_config.use_mla, bool)
and model_config.use_mla):
if hasattr(model_config, "use_mla") and isinstance(model_config.use_mla, bool) and model_config.use_mla:
self.use_mla = True
self.use_layerwise = use_layerwize
self.tp_rank = get_tensor_model_parallel_rank()
@@ -53,19 +59,16 @@ class KVPoolWorker:
self.pp_rank = (parallel_config.rank // self.tp_size) % self.pp_size
self.pcp_size = get_pcp_group().world_size
self.pcp_rank = get_pcp_group(
).rank_in_group if self.pcp_size > 1 else 0
self.pcp_rank = get_pcp_group().rank_in_group if self.pcp_size > 1 else 0
self.dcp_size = get_decode_context_model_parallel_world_size()
self.dcp_rank = get_decode_context_model_parallel_rank(
) if self.dcp_size > 1 else 0
self.dcp_rank = get_decode_context_model_parallel_rank() if self.dcp_size > 1 else 0
self.kv_role = vllm_config.kv_transfer_config.kv_role
self.load_async = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"load_async", False)
self.load_async = vllm_config.kv_transfer_config.kv_connector_extra_config.get("load_async", False)
self.consumer_is_to_put = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"consumer_is_to_put", False)
self.backend = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"backend", "mooncake")
"consumer_is_to_put", False
)
self.backend = vllm_config.kv_transfer_config.kv_connector_extra_config.get("backend", "mooncake")
self.block_size = vllm_config.cache_config.block_size
if self.pcp_size > 1:
@@ -88,7 +91,7 @@ class KVPoolWorker:
self.put_step = 1
self.metadata = KeyMetadata(
model_config.model.rstrip('/').split('/')[-1],
model_config.model.rstrip("/").split("/")[-1],
self.head_or_tp_rank,
self.pcp_rank,
self.dcp_rank,
@@ -99,40 +102,28 @@ class KVPoolWorker:
if self.kv_role == "kv_consumer" and self.consumer_is_to_put:
num_hidden_layers = model_config.hf_text_config.num_hidden_layers
partition_list_str = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"prefill_pp_layer_partition", None)
prefill_pp_size = int(
vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"prefill_pp_size", 1))
"prefill_pp_layer_partition", None
)
prefill_pp_size = int(vllm_config.kv_transfer_config.kv_connector_extra_config.get("prefill_pp_size", 1))
if partition_list_str is not None:
try:
partitions = [
int(layer) for layer in partition_list_str.split(",")
]
partitions = [int(layer) for layer in partition_list_str.split(",")]
except ValueError as err:
raise ValueError("Invalid partition string: {}".format(
partition_list_str)) from err
raise ValueError("Invalid partition string: {}".format(partition_list_str)) from err
if len(partitions) != prefill_pp_size:
raise ValueError(
f"{len(partitions)=} does not match {prefill_pp_size=}."
)
raise ValueError(f"{len(partitions)=} does not match {prefill_pp_size=}.")
if sum(partitions) != num_hidden_layers:
raise ValueError(
f"{sum(partitions)=} does not match {num_hidden_layers=}."
)
raise ValueError(f"{sum(partitions)=} does not match {num_hidden_layers=}.")
else:
layers_per_partition = num_hidden_layers // prefill_pp_size
partitions = [
layers_per_partition for _ in range(prefill_pp_size)
]
partitions = [layers_per_partition for _ in range(prefill_pp_size)]
if remaining_layers := num_hidden_layers % prefill_pp_size:
for i in range(2, remaining_layers + 2):
partitions[-i] += 1
self.token_database = ChunkedTokenDatabase(self.metadata,
self.block_size,
self.use_mla, partitions)
self.token_database = ChunkedTokenDatabase(self.metadata, self.block_size, self.use_mla, partitions)
real_backend = backend_map.get(self.backend.lower())
@@ -142,10 +133,11 @@ class KVPoolWorker:
self.put_step = 1
self.m_store = real_backend( # type: ignore[misc]
parallel_config)
parallel_config
)
self.kv_send_thread: Optional[KVTransferThread] = None
self.kv_recv_thread: Optional[KVTransferThread] = None
self.kv_send_thread: KVTransferThread | None = None
self.kv_recv_thread: KVTransferThread | None = None
self.finished_store_req: set[str] = set()
@@ -162,11 +154,14 @@ class KVPoolWorker:
block_shape_pe = first_kv_cache_tuple[1].shape[-block_rank:]
self.block_len = [
first_kv_cache[0].element_size() * math.prod(block_shape_norm),
first_kv_cache[1].element_size() * math.prod(block_shape_pe)
first_kv_cache[1].element_size() * math.prod(block_shape_pe),
]
logger.info(
"num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s",
self.num_blocks, block_shape_norm, block_shape_pe)
self.num_blocks,
block_shape_norm,
block_shape_pe,
)
else:
# [num_block, block_size, num_head, hidden_dim]
self.num_blocks = first_kv_cache.shape[0]
@@ -174,11 +169,9 @@ class KVPoolWorker:
block_rank = 3 # [block_size, kv_heads, head_dim]
block_shape = first_kv_cache.shape[-block_rank:]
self.block_len = [kv_elem_size * math.prod(block_shape)]
logger.info("num_blocks: %s, block_shape: %s", self.num_blocks,
block_shape)
logger.info("num_blocks: %s, block_shape: %s", self.num_blocks, block_shape)
logger.info("Registering KV_Caches. use_mla: %s, shape %s",
self.use_mla, first_kv_cache.shape)
logger.info("Registering KV_Caches. use_mla: %s, shape %s", self.use_mla, first_kv_cache.shape)
self.kv_caches = kv_caches
self.kv_caches_base_addr = []
@@ -194,8 +187,7 @@ class KVPoolWorker:
ptrs.append(base_addr)
lengths.append(region_len)
else:
cache_list = [cache_or_caches
] if self.use_mla else cache_or_caches
cache_list = [cache_or_caches] if self.use_mla else cache_or_caches
for cache in cache_list:
base_addr = cache.data_ptr()
self.kv_caches_base_addr.append(base_addr)
@@ -208,33 +200,50 @@ class KVPoolWorker:
if self.use_layerwise:
self.get_event = threading.Event()
if self.kv_role in ['kv_producer', 'kv_both']:
if self.kv_role in ["kv_producer", "kv_both"]:
ready_event_sending = threading.Event()
self.kv_send_thread = KVCacheStoreLayerSendingThread(
self.m_store, self.token_database, self.block_size,
self.tp_rank, self.dcp_size, self.put_step,
ready_event_sending, self.num_layers)
self.m_store,
self.token_database,
self.block_size,
self.tp_rank,
self.dcp_size,
self.put_step,
ready_event_sending,
self.num_layers,
)
self.kv_send_thread.start()
ready_event = threading.Event()
self.kv_recv_thread = KVCacheStoreLayerRecvingThread(
self.m_store, self.token_database, self.block_size,
self.tp_rank, self.dcp_size, ready_event, self.get_event)
self.m_store,
self.token_database,
self.block_size,
self.tp_rank,
self.dcp_size,
ready_event,
self.get_event,
)
self.kv_recv_thread.start()
ready_event.wait()
else:
if self.kv_role in ['kv_producer', 'kv_both'
] or self.consumer_is_to_put:
if self.kv_role in ["kv_producer", "kv_both"] or self.consumer_is_to_put:
ready_event_sending = threading.Event()
self.kv_send_thread = KVCacheStoreSendingThread(
self.m_store, self.token_database, self.block_size,
self.tp_rank, self.dcp_size, self.put_step, self.kv_role,
ready_event_sending)
self.m_store,
self.token_database,
self.block_size,
self.tp_rank,
self.dcp_size,
self.put_step,
self.kv_role,
ready_event_sending,
)
self.kv_send_thread.start()
if self.load_async:
ready_event = threading.Event()
self.kv_recv_thread = KVCacheStoreRecvingThread(
self.m_store, self.token_database, self.block_size,
self.tp_rank, self.dcp_size, ready_event)
self.m_store, self.token_database, self.block_size, self.tp_rank, self.dcp_size, ready_event
)
self.kv_recv_thread.start()
ready_event.wait()
@@ -243,12 +252,12 @@ class KVPoolWorker:
self.layerwise_retrievers = []
for request in metadata.requests:
load_spec = request.load_spec
if load_spec is None or not load_spec.can_load: #load =0
if load_spec is None or not load_spec.can_load: # load =0
continue
token_len = request.token_len_chunk
if (load_spec.kvpool_cached_tokens % self.block_size
!= 0) and (load_spec.kvpool_cached_tokens
== token_len - 1):
if (load_spec.kvpool_cached_tokens % self.block_size != 0) and (
load_spec.kvpool_cached_tokens == token_len - 1
):
token_len = request.load_spec.kvpool_cached_tokens + 1
else:
token_len = request.load_spec.kvpool_cached_tokens
@@ -260,30 +269,27 @@ class KVPoolWorker:
else:
if self.load_async:
self.kv_recv_thread.add_request( # type: ignore[union-attr]
request, )
request,
)
else:
addr_list = []
size_list = []
key_list = []
mask_num = (request.load_spec.vllm_cached_tokens //
self.block_size * self.block_size)
mask_num = request.load_spec.vllm_cached_tokens // self.block_size * self.block_size
for start, end, key in self.token_database.process_tokens(
token_len, request.block_hashes, mask_num):
addr, size, _ = self.token_database.prepare_value(
start, end, request.block_ids)
token_len, request.block_hashes, mask_num
):
addr, size, _ = self.token_database.prepare_value(start, end, request.block_ids)
key_list.append(key.to_string())
addr_list.append(addr)
size_list.append(size)
key_list_c = key_list[self.tp_rank % len(
key_list):] + key_list[:self.tp_rank % len(key_list)]
addr_list_c = addr_list[self.tp_rank %
len(addr_list
):] + addr_list[:self.tp_rank %
len(addr_list)]
size_list_c = size_list[self.tp_rank %
len(size_list
):] + size_list[:self.tp_rank %
len(size_list)]
key_list_c = key_list[self.tp_rank % len(key_list) :] + key_list[: self.tp_rank % len(key_list)]
addr_list_c = (
addr_list[self.tp_rank % len(addr_list) :] + addr_list[: self.tp_rank % len(addr_list)]
)
size_list_c = (
size_list[self.tp_rank % len(size_list) :] + size_list[: self.tp_rank % len(size_list)]
)
self.m_store.get(key_list_c, addr_list_c, size_list_c)
def wait_for_layer_load(self) -> None:
@@ -294,8 +300,7 @@ class KVPoolWorker:
num_retrieved_tokens = ret_token_mask.sum().item()
logger.debug(f"Retrieved {num_retrieved_tokens} tokens")
def save_kv_layer(self,
connector_metadata: AscendConnectorMetadata) -> None:
def save_kv_layer(self, connector_metadata: AscendConnectorMetadata) -> None:
if self.current_layer == 0:
self.layerwise_storers = []
current_event = None
@@ -336,15 +341,17 @@ class KVPoolWorker:
continue
request.current_event = current_event
self.kv_send_thread.add_stored_request( # type: ignore[union-attr]
request.req_id)
self.kv_send_thread.add_stored_request( # type: ignore[union-attr]
request.req_id
)
self.kv_send_thread.add_request( # type: ignore[union-attr]
request, )
request,
)
def retrieve_layer(
self,
request: ReqMeta,
) -> Generator[Optional[torch.Tensor], None, None]:
) -> Generator[torch.Tensor | None, None, None]:
"""
Retrieve the KV cache in a layerwise manner.
@@ -359,12 +366,14 @@ class KVPoolWorker:
return: A generator that yields Optional[torch.Tensor]. The tensor will
be the boolean mask indicating which tokens are retrieved and will
only be returned in the last iteration.
only be returned in the last iteration.
"""
token_len = request.token_len_chunk
mask_num = (
request.load_spec.vllm_cached_tokens # type: ignore[union-attr]
// self.block_size * self.block_size)
// self.block_size
* self.block_size
)
num_required_tokens = token_len - mask_num
ret_mask = torch.zeros(token_len, dtype=torch.bool, device="cpu")
@@ -373,8 +382,7 @@ class KVPoolWorker:
ends = []
keys = []
first_flag = True
for start, end, key in self.token_database.process_tokens(
token_len, request.block_hashes, mask_num):
for start, end, key in self.token_database.process_tokens(token_len, request.block_hashes, mask_num):
keys_multi_layer = key.split_layers(self.num_layers)
starts.append(start)
ends.append(end)
@@ -386,16 +394,16 @@ class KVPoolWorker:
keys = [list(row) for row in zip(*keys)] # [num_layer,block_num]
for layer_id, keys_multi_chunk in enumerate(keys):
if not first_flag:
is_finish = self.get_event.wait(timeout=3) #try---cache
is_finish = self.get_event.wait(timeout=3) # try---cache
if not is_finish:
logger.info("Layerwise get failed")
self.get_event.clear()
req_meta = LasyerMultiBlockReqMeta(request.req_id,
keys_multi_chunk, starts,
ends, request.block_ids,
layer_id)
req_meta = LasyerMultiBlockReqMeta(
request.req_id, keys_multi_chunk, starts, ends, request.block_ids, layer_id
)
self.kv_recv_thread.add_request( # type: ignore[union-attr, call-arg]
req_meta) # type: ignore[union-attr, call-arg, arg-type]
req_meta
) # type: ignore[union-attr, call-arg, arg-type]
first_flag = False
yield None
else:
@@ -405,16 +413,14 @@ class KVPoolWorker:
yield None
retrieved_tokens = torch.sum(ret_mask)
logger.debug(f"Retrieved {retrieved_tokens} "
f"out of {num_required_tokens} "
f"out of total {token_len} tokens")
logger.debug(f"Retrieved {retrieved_tokens} out of {num_required_tokens} out of total {token_len} tokens")
yield ret_mask
def store_layer(
self,
request: ReqMeta,
current_event: Optional[torch.npu.Event],
current_event: torch.npu.Event | None,
) -> Generator[None, None, None]:
"""
Store the KV cache in a layerwise manner.
@@ -439,69 +445,88 @@ class KVPoolWorker:
starts = []
ends = []
keys = []
for start, end, key in self.token_database.process_tokens(
request.token_len_chunk, request.block_hashes):
for start, end, key in self.token_database.process_tokens(request.token_len_chunk, request.block_hashes):
keys_multi_layer = key.split_layers(self.num_layers)
starts.append(start)
ends.append(end)
keys.append(keys_multi_layer) #[block_num,layer_num]
keys.append(keys_multi_layer) # [block_num,layer_num]
if keys:
keys = [list(row) for row in zip(*keys)] #[layer_num,block_num]
keys = [list(row) for row in zip(*keys)] # [layer_num,block_num]
for layer_id, keys_multi_chunk in enumerate(keys):
req_meta = LasyerMultiBlockReqMeta(request.req_id,
keys_multi_chunk, starts,
ends, request.block_ids,
layer_id,
request.is_last_chunk,
current_event)
req_meta = LasyerMultiBlockReqMeta(
request.req_id,
keys_multi_chunk,
starts,
ends,
request.block_ids,
layer_id,
request.is_last_chunk,
current_event,
)
self.kv_send_thread.add_request( # type: ignore[union-attr, call-arg]
req_meta) # type: ignore[union-attr, call-arg, arg-type]
req_meta
) # type: ignore[union-attr, call-arg, arg-type]
yield
else:
for layer_id in range(self.num_layers):
yield
def get_finished(self,
finished_req_ids: set[str], meta:AscendConnectorMetadata) -> tuple[set[str], set[str]]:
def get_finished(self, finished_req_ids: set[str], meta: AscendConnectorMetadata) -> tuple[set[str], set[str]]:
done_sending = (
self.get_and_clear_finished_requests(
finished_req_ids, meta # type: ignore[union-attr]
) if self.kv_role in ['kv_producer', 'kv_both']
or self.consumer_is_to_put else set())
finished_req_ids,
meta, # type: ignore[union-attr]
)
if self.kv_role in ["kv_producer", "kv_both"] or self.consumer_is_to_put
else set()
)
done_recving = (
self.kv_recv_thread.
get_and_clear_finished_requests( # type: ignore[union-attr]
) if self.load_async else set())
self.kv_recv_thread.get_and_clear_finished_requests( # type: ignore[union-attr]
)
if self.load_async
else set()
)
logger.debug(
"Number of completed KV cache send requests: %d, receive "
"requests: %d, tp_rank:%d", len(done_sending), len(done_recving),
self.tp_rank)
"Number of completed KV cache send requests: %d, receive requests: %d, tp_rank:%d",
len(done_sending),
len(done_recving),
self.tp_rank,
)
return done_sending, done_recving
def get_and_clear_finished_requests(self, finished_req_ids, meta:AscendConnectorMetadata) -> set[str]:
def get_and_clear_finished_requests(self, finished_req_ids, meta: AscendConnectorMetadata) -> set[str]:
finished_sending = set()
for req_id in meta.preempted_req_ids:
self.kv_send_thread.delete_finished_stored_request( # type: ignore[union-attr]
req_id)
req_id
)
for req_id in self.kv_send_thread.stored_requests.copy( # type: ignore[union-attr]
):
if self.kv_send_thread.stored_requests[ # type: ignore[union-attr]
req_id] == 0 and req_id in self.finished_store_req:
if (
self.kv_send_thread.stored_requests[ # type: ignore[union-attr]
req_id
]
== 0
and req_id in self.finished_store_req
):
self.finished_store_req.remove(req_id)
finished_sending.add(req_id)
self.kv_send_thread.delete_finished_stored_request( # type: ignore[union-attr]
req_id)
req_id
)
for req_id in finished_req_ids:
req_remain_jobs = self.kv_send_thread.stored_requests.get( # type: ignore[union-attr]
req_id)
req_id
)
if req_remain_jobs == 0:
finished_sending.add(req_id)
self.kv_send_thread.delete_finished_stored_request( # type: ignore[union-attr]
req_id)
req_id
)
elif req_remain_jobs is not None:
self.finished_store_req.add(req_id)
@@ -522,8 +547,7 @@ class KVPoolWorker:
keys = []
try:
starts = []
for start, end, key in self.token_database.process_tokens(
token_len, block_hashes):
for start, end, key in self.token_database.process_tokens(token_len, block_hashes):
if use_layerwise:
keys_multi_layer = key.split_layers(self.num_layers)
for item in keys_multi_layer:
@@ -560,8 +584,7 @@ class KVPoolWorker:
keys = []
try:
starts = []
for start, end, key in self.token_database.process_tokens(
token_len, block_hashes):
for start, end, key in self.token_database.process_tokens(token_len, block_hashes):
if use_layerwise:
keys_multi_layer = key.split_layers(self.num_layers)
for item in keys_multi_layer:
@@ -574,25 +597,25 @@ class KVPoolWorker:
for i in range(1, min(self.tp_size, self.num_kv_head)):
for item in keys:
new_str = item.replace( # type: ignore[attr-defined]
"@head_or_tp_rank:0", f"@head_or_tp_rank:{i}", 1)
"@head_or_tp_rank:0", f"@head_or_tp_rank:{i}", 1
)
multi_tp_keys.append(new_str)
for i in range(1, self.pp_size):
for item in keys:
new_str = item.replace( # type: ignore[attr-defined]
"@pp_rank:0", f"@pp_rank:{i}", 1)
"@pp_rank:0", f"@pp_rank:{i}", 1
)
multi_tp_keys.append(new_str)
res = self.m_store.exists(
multi_tp_keys) # type: ignore[assignment]
res = self.m_store.exists(multi_tp_keys) # type: ignore[assignment]
num_block = len(keys)
if use_layerwise:
res = self.check_all_layers_exists(res, self.num_layers)
num_block = len(keys) // self.num_layers
multi_tp_values = [
res[i * num_block:(i + 1) * num_block] # type: ignore[index]
for i in range(
min(self.tp_size, self.num_kv_head) * self.pp_size)
res[i * num_block : (i + 1) * num_block] # type: ignore[index]
for i in range(min(self.tp_size, self.num_kv_head) * self.pp_size)
]
index = self.find_min_first_non_one_index(multi_tp_values)
if index != -1:
@@ -603,8 +626,7 @@ class KVPoolWorker:
return start
return end
def check_all_layers_exists(self, res: list[int],
num_layers: int) -> list[int]:
def check_all_layers_exists(self, res: list[int], num_layers: int) -> list[int]:
total_chunks = len(res) // num_layers
result = []
@@ -618,7 +640,6 @@ class KVPoolWorker:
def find_min_first_non_one_index(self, arr):
try:
return min(idx for row in arr for idx, val in enumerate(row)
if val != 1)
return min(idx for row in arr for idx, val in enumerate(row) if val != 1)
except ValueError:
return -1

View File

@@ -1,20 +1,17 @@
import time
from collections import defaultdict
from typing import Optional
from vllm.logger import logger
from vllm.utils.hashing import sha256
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock)
from vllm.v1.core.single_type_kv_cache_manager import \
get_manager_for_kv_cache_spec
from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock
from vllm.v1.core.single_type_kv_cache_manager import get_manager_for_kv_cache_spec
from vllm.v1.kv_cache_interface import KVCacheSpec
from vllm.v1.metrics.stats import (PrefixCacheStats, CachingMetrics)
from vllm.v1.metrics.stats import CachingMetrics, PrefixCacheStats
from vllm.v1.request import Request
class CPUCacheStats:
def __init__(self, enable_prefix_caching: bool, log_stats: bool = False):
self.enable_prefix_caching = enable_prefix_caching
self.log_stats = log_stats
@@ -27,10 +24,9 @@ class CPUCacheStats:
# Log the prefix cache hit rate every 10 seconds.
if current_time_sec - self.time_sec >= 10:
self.time_sec = current_time_sec
logger.info("CPU Prefix cache hit rate: %.1f%%",
self.cpu_prefix_cache_metrics.hit_rate * 100)
logger.info("CPU Prefix cache hit rate: %.1f%%", self.cpu_prefix_cache_metrics.hit_rate * 100)
def make_prefix_cache_stats(self) -> Optional[PrefixCacheStats]:
def make_prefix_cache_stats(self) -> PrefixCacheStats | None:
"""Get (and reset) the prefix cache stats.
Returns:
The current prefix caching stats, or None if logging is disabled.
@@ -57,7 +53,6 @@ class CPUCacheStats:
class CPUKVCacheManager:
def __init__(
self,
kv_cache_spec: KVCacheSpec,
@@ -70,30 +65,26 @@ class CPUKVCacheManager:
self.num_cpu_blocks = num_cpu_blocks
self.caching_hash_fn = sha256 if caching_hash_algo == "sha256" else hash
self.use_eagle = use_eagle
self.block_pool = BlockPool(self.num_cpu_blocks, True,
enable_kv_cache_events)
self.block_pool = BlockPool(self.num_cpu_blocks, True, enable_kv_cache_events)
self.single_type_manager = get_manager_for_kv_cache_spec(
kv_cache_spec=kv_cache_spec,
block_pool=self.block_pool,
kv_cache_group_id=0,
)
# Record kv block hashes, avoid redundant computation.
self.req_to_block_hashes: defaultdict[
str, list[BlockHash]] = defaultdict(list)
self.req_to_block_hashes: defaultdict[str, list[BlockHash]] = defaultdict(list)
# Record blocks touched in get_matched_num_and_touch().
self.req_to_computed_blocks: defaultdict[
str, list[KVCacheBlock]] = defaultdict(list)
self.req_to_computed_blocks: defaultdict[str, list[KVCacheBlock]] = defaultdict(list)
# Record the request that failed to allocate.
self.req_failed_to_allocate: defaultdict[str, bool] = defaultdict(bool)
self.req_to_num_tokens: defaultdict[str, int] = defaultdict(int)
self.cpu_cache_stats = CPUCacheStats(enable_prefix_caching=True,
log_stats=True)
self.cpu_cache_stats = CPUCacheStats(enable_prefix_caching=True, log_stats=True)
# Record request that will be free after finish sending
self.req_to_free: defaultdict[str, Request] = defaultdict(Request)
def get_matched_num_and_touch(self, request: Request) -> tuple[int, bool]:
# When the request requires prompt logprobs, we skip prefix caching.
if (request.sampling_params.prompt_logprobs is not None):
if request.sampling_params.prompt_logprobs is not None:
return 0, False
request_id = request.request_id
# The block hashes for the request may already be computed
@@ -119,10 +110,8 @@ class CPUKVCacheManager:
# cup prefix cache status set and log
assert self.cpu_cache_stats is not None and self.cpu_cache_stats.prefix_cache_stats is not None
self.cpu_cache_stats.set_cache_stats(request.num_tokens,
num_computed_tokens)
self.cpu_cache_stats.cpu_prefix_cache_metrics.observe(
self.cpu_cache_stats.prefix_cache_stats)
self.cpu_cache_stats.set_cache_stats(request.num_tokens, num_computed_tokens)
self.cpu_cache_stats.cpu_prefix_cache_metrics.observe(self.cpu_cache_stats.prefix_cache_stats)
self.cpu_cache_stats.log()
return num_computed_tokens, False
@@ -130,12 +119,10 @@ class CPUKVCacheManager:
def _release_ahead_touch(self, request_id: str):
computed_blocks = self.req_to_computed_blocks[request_id]
if computed_blocks:
self.single_type_manager.block_pool.free_blocks(
reversed(computed_blocks))
self.single_type_manager.block_pool.free_blocks(reversed(computed_blocks))
self.req_to_computed_blocks.pop(request_id, None)
def allocate_slots(self, req_to_num_tokens: dict[str, int],
unallocated_req_ids: set[str]) -> dict[str, list[int]]:
def allocate_slots(self, req_to_num_tokens: dict[str, int], unallocated_req_ids: set[str]) -> dict[str, list[int]]:
for request_id in unallocated_req_ids:
self._free_slots(request_id)
req_to_new_blocks = {}
@@ -143,44 +130,34 @@ class CPUKVCacheManager:
if self.req_failed_to_allocate[request_id]:
continue
new_computed_blocks = self.req_to_computed_blocks[request_id]
num_blocks_to_allocate = (
self.single_type_manager.get_num_blocks_to_allocate(
request_id=request_id,
num_tokens=num_tokens,
new_computed_blocks=new_computed_blocks,
))
num_blocks_to_allocate = self.single_type_manager.get_num_blocks_to_allocate(
request_id=request_id,
num_tokens=num_tokens,
new_computed_blocks=new_computed_blocks,
)
if num_blocks_to_allocate > self.block_pool.get_num_free_blocks():
self._release_ahead_touch(request_id)
self.req_failed_to_allocate[request_id] = True
continue
# Append the new computed blocks to the request blocks until now to
# avoid the case where the new blocks cannot be allocated.
self.single_type_manager.save_new_computed_blocks(
request_id, new_computed_blocks)
self.single_type_manager.save_new_computed_blocks(request_id, new_computed_blocks)
# Allocate new blocks but do not cache now.
new_blocks = self.single_type_manager.allocate_new_blocks(
request_id, num_tokens)
new_blocks = self.single_type_manager.allocate_new_blocks(request_id, num_tokens)
self.req_to_num_tokens[request_id] = num_tokens
# No need to release ref_cnt because we use officially.
self.req_to_computed_blocks.pop(request_id, None)
req_to_new_blocks[request_id] = [
block.block_id for block in new_computed_blocks + new_blocks
]
req_to_new_blocks[request_id] = [block.block_id for block in new_computed_blocks + new_blocks]
return req_to_new_blocks
def record_request_cache_and_free_slots(self, request: Request):
logger.debug(
f"record_request_cache_and_free_slots for request {request.request_id} in cpu_kv_cache_manager"
)
logger.debug(f"record_request_cache_and_free_slots for request {request.request_id} in cpu_kv_cache_manager")
self.req_to_free[request.request_id] = request
def cache_and_free_slots(self, request_id: str):
logger.debug(
f"Cache and free slots for request {request_id} in cpu_kv_cache_manager"
)
logger.debug(f"Cache and free slots for request {request_id} in cpu_kv_cache_manager")
if request_id not in self.req_to_free:
logger.Error(
f"request {request_id} not in req_to_free, maybe bug!")
logger.Error(f"request {request_id} not in req_to_free, maybe bug!")
return
request = self.req_to_free[request_id]
if not self.req_failed_to_allocate[request_id]:
@@ -189,8 +166,7 @@ class CPUKVCacheManager:
self.req_to_num_tokens[request_id],
)
self._free_slots(request_id)
logger.debug(
f"delete request {request_id} in cpu_kv_cache_manager req_to_free")
logger.debug(f"delete request {request_id} in cpu_kv_cache_manager req_to_free")
del self.req_to_free[request_id]
def _free_slots(self, request_id: str):

View File

@@ -5,15 +5,15 @@ import queue
import threading
import time
from collections import defaultdict
from collections.abc import Sequence
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional, Sequence
from typing import TYPE_CHECKING, Any, Optional
import torch
from vllm.attention.layer import Attention, MLAAttention
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
from vllm.logger import logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
@@ -23,12 +23,14 @@ from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheSpec
from vllm_ascend.distributed.kv_transfer.kv_pool.cpu_offload.metadata import (
MetadataServer, MetadataServerProc, MLAConfig)
MetadataServer,
MetadataServerProc,
MLAConfig,
)
if TYPE_CHECKING:
from vllm.v1.attention.backend import AttentionMetadata #type: ignore
from vllm.forward_context import ForwardContext
from vllm.v1.attention.backend import AttentionMetadata # type: ignore
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.request import Request
@@ -59,20 +61,15 @@ class CPUOffloadingConnectorMetadata(KVConnectorMetadata):
class CPUOffloadingConnector(KVConnectorBase_V1):
def __init__(self,
vllm_config: VllmConfig,
role: KVConnectorRole,
kv_cache_config: Optional["KVCacheConfig"] = None):
def __init__(
self, vllm_config: VllmConfig, role: KVConnectorRole, kv_cache_config: Optional["KVCacheConfig"] = None
):
self._connector_metadata = CPUOffloadingConnectorMetadata(requests={}, finished_req_ids=set())
if not vllm_config.cache_config.enable_prefix_caching:
self.connector_scheduler: Optional[
CPUOffloadingConnectorScheduler] = None
self.connector_worker: Optional[
CPUOffloadingConnectorWorker] = None
self.connector_scheduler: CPUOffloadingConnectorScheduler | None = None
self.connector_worker: CPUOffloadingConnectorWorker | None = None
elif role == KVConnectorRole.SCHEDULER:
self.connector_scheduler = CPUOffloadingConnectorScheduler(
vllm_config)
self.connector_scheduler = CPUOffloadingConnectorScheduler(vllm_config)
self.connector_worker = None
elif role == KVConnectorRole.WORKER:
self.connector_scheduler = None
@@ -82,11 +79,9 @@ class CPUOffloadingConnector(KVConnectorBase_V1):
# Worker-side methods
# ==============================
def bind_connector_metadata(
self, connector_metadata: KVConnectorMetadata) -> None:
def bind_connector_metadata(self, connector_metadata: KVConnectorMetadata) -> None:
if self.connector_worker is not None:
assert isinstance(connector_metadata,
CPUOffloadingConnectorMetadata)
assert isinstance(connector_metadata, CPUOffloadingConnectorMetadata)
self.connector_worker.bind_connector_metadata(connector_metadata)
def clear_connector_metadata(self) -> None:
@@ -97,8 +92,7 @@ class CPUOffloadingConnector(KVConnectorBase_V1):
if self.connector_worker is not None:
self.connector_worker.register_kv_caches(kv_caches)
def start_load_kv(self, forward_context: "ForwardContext",
**kwargs) -> None:
def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
if self.connector_worker is not None:
self.connector_worker.start_load_kv()
@@ -106,53 +100,42 @@ class CPUOffloadingConnector(KVConnectorBase_V1):
if self.connector_worker is not None:
self.connector_worker.wait_for_layer_load()
def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
attn_metadata: "AttentionMetadata", **kwargs) -> None:
def save_kv_layer(
self, layer_name: str, kv_layer: torch.Tensor, attn_metadata: "AttentionMetadata", **kwargs
) -> None:
pass
def wait_for_save(self):
pass
def get_finished(
self, finished_req_ids: set[str]
) -> tuple[Optional[set[str]], Optional[set[str]]]:
def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str] | None, set[str] | None]:
assert self.connector_worker is not None
return self.connector_worker.get_finished(), None
# Scheduler-side methods
# ==============================
def get_num_new_matched_tokens(
self, request: "Request",
num_computed_tokens: int) -> tuple[int, bool]:
def get_num_new_matched_tokens(self, request: "Request", num_computed_tokens: int) -> tuple[int, bool]:
if self.connector_scheduler is not None:
return self.connector_scheduler.get_num_new_matched_tokens(
request, num_computed_tokens)
return self.connector_scheduler.get_num_new_matched_tokens(request, num_computed_tokens)
return 0, False
def update_state_after_alloc(self, request: "Request",
blocks: "KVCacheBlocks",
num_external_tokens: int):
def update_state_after_alloc(self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int):
if self.connector_scheduler is not None:
return self.connector_scheduler.update_state_after_alloc(request)
def build_connector_meta(
self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
def build_connector_meta(self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
if self.connector_scheduler is not None:
return self.connector_scheduler.build_connector_meta(
scheduler_output)
return self.connector_scheduler.build_connector_meta(scheduler_output)
return KVConnectorMetadata()
def request_finished(
self, request: "Request",
block_ids: list[int]) -> tuple[bool, Optional[dict[str, Any]]]:
def request_finished(self, request: "Request", block_ids: list[int]) -> tuple[bool, dict[str, Any] | None]:
if self.connector_scheduler is not None:
self.connector_scheduler.request_finished(request)
return True, None
class CPUOffloadingConnectorScheduler:
def __init__(self, vllm_config: VllmConfig):
logger.info("init CPUOffloadingConnectorScheduler")
self.vllm_config = vllm_config
@@ -165,22 +148,17 @@ class CPUOffloadingConnectorScheduler:
self.zmq_rpc_client = MetadataServer.ZMQRPCClient()
self.zmq_rpc_client.call("post_init")
if vllm_config.kv_transfer_config is not None:
self.swap_in_threshold = vllm_config.kv_transfer_config.get_from_extra_config(
"swap_in_threshold", 0)
self.swap_in_threshold = vllm_config.kv_transfer_config.get_from_extra_config("swap_in_threshold", 0)
else:
self.swap_in_threshold = 0
logger.info(f"swap_in_threshold: {self.swap_in_threshold}")
def get_num_new_matched_tokens(
self, ori_request: "Request",
num_computed_tokens: int) -> tuple[int, bool]:
def get_num_new_matched_tokens(self, ori_request: "Request", num_computed_tokens: int) -> tuple[int, bool]:
request = copy.deepcopy(ori_request)
request.get_hash_new_full_blocks = None
num_cpu_computed_tokens, load_async = self.zmq_rpc_client.call(
"get_matched_num_and_touch", request)
num_cpu_computed_tokens, load_async = self.zmq_rpc_client.call("get_matched_num_and_touch", request)
self.num_gpu_computed_tokens[request.request_id] = num_computed_tokens
self.num_cpu_computed_tokens[
request.request_id] = num_cpu_computed_tokens
self.num_cpu_computed_tokens[request.request_id] = num_cpu_computed_tokens
if num_cpu_computed_tokens - num_computed_tokens >= self.swap_in_threshold:
return num_cpu_computed_tokens - num_computed_tokens, load_async
else:
@@ -189,29 +167,22 @@ class CPUOffloadingConnectorScheduler:
def update_state_after_alloc(self, request: "Request"):
self.allocated_req_ids.add(request.request_id)
def build_connector_meta(
self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
def build_connector_meta(self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
num_tokens = {}
# process scheduled_new_reqs
for req in scheduler_output.scheduled_new_reqs:
req_id = req.req_id
num_tokens[req_id] = (
req.num_computed_tokens +
scheduler_output.num_scheduled_tokens[req_id])
num_tokens[req_id] = req.num_computed_tokens + scheduler_output.num_scheduled_tokens[req_id]
# process scheduled_cached_reqs
cached_reqs = scheduler_output.scheduled_cached_reqs
for idx, req_id in enumerate(cached_reqs.req_ids):
num_tokens[req_id] = (
cached_reqs.num_computed_tokens[idx] +
scheduler_output.num_scheduled_tokens[req_id])
num_tokens[req_id] = cached_reqs.num_computed_tokens[idx] + scheduler_output.num_scheduled_tokens[req_id]
unallocated_req_ids = set(self.num_gpu_computed_tokens.keys() -
self.allocated_req_ids -
scheduler_output.num_scheduled_tokens.keys())
new_cpu_block_ids = self.zmq_rpc_client.call("allocate_slots",
num_tokens,
unallocated_req_ids)
unallocated_req_ids = set(
self.num_gpu_computed_tokens.keys() - self.allocated_req_ids - scheduler_output.num_scheduled_tokens.keys()
)
new_cpu_block_ids = self.zmq_rpc_client.call("allocate_slots", num_tokens, unallocated_req_ids)
metadata = CPUOffloadingConnectorMetadata(
requests={},
finished_req_ids=set(self.finished_req_ids),
@@ -222,22 +193,22 @@ class CPUOffloadingConnectorScheduler:
metadata.requests[req_id] = ReqMeta(
gpu_block_ids=[] if gpu_block_ids is None else gpu_block_ids,
cpu_block_ids=new_cpu_block_ids.get(req_id, []),
num_scheduled_tokens=scheduler_output.
num_scheduled_tokens[req_id],
num_scheduled_tokens=scheduler_output.num_scheduled_tokens[req_id],
num_computed_tokens=req.num_computed_tokens,
num_gpu_computed_tokens=self.num_gpu_computed_tokens[req_id],
num_cpu_computed_tokens=self.num_cpu_computed_tokens[req_id])
num_cpu_computed_tokens=self.num_cpu_computed_tokens[req_id],
)
for idx, req_id in enumerate(cached_reqs.req_ids):
gpu_block_ids = cached_reqs.new_block_ids[idx]
metadata.requests[req_id] = ReqMeta(
gpu_block_ids=[] if gpu_block_ids is None else gpu_block_ids,
cpu_block_ids=new_cpu_block_ids.get(req_id, []),
num_scheduled_tokens=scheduler_output.
num_scheduled_tokens[req_id],
num_scheduled_tokens=scheduler_output.num_scheduled_tokens[req_id],
num_computed_tokens=cached_reqs.num_computed_tokens[idx],
num_gpu_computed_tokens=cached_reqs.num_computed_tokens[idx],
num_cpu_computed_tokens=cached_reqs.num_computed_tokens[idx])
num_cpu_computed_tokens=cached_reqs.num_computed_tokens[idx],
)
self.num_gpu_computed_tokens.clear()
self.num_cpu_computed_tokens.clear()
self.allocated_req_ids.clear()
@@ -249,12 +220,10 @@ class CPUOffloadingConnectorScheduler:
request.get_hash_new_full_blocks = None
self.finished_req_ids.append(request.request_id)
# inform metadata server to record request, and free it after finish sending
self.zmq_rpc_client.call("record_request_cache_and_free_slots",
request)
self.zmq_rpc_client.call("record_request_cache_and_free_slots", request)
class CPUOffloadingConnectorWorker:
def __init__(self, vllm_config: VllmConfig):
logger.info("init CPUOffloadingConnectorWorker")
self.vllm_config = vllm_config
@@ -289,7 +258,7 @@ class CPUOffloadingConnectorWorker:
def init_metadata_server(self, vllm_config: VllmConfig):
self.metadata_thread = threading.Thread(
target=MetadataServerProc.run_metadata_server,
args=(vllm_config, ),
args=(vllm_config,),
)
self.metadata_thread.daemon = True
self.metadata_thread.start()
@@ -304,18 +273,15 @@ class CPUOffloadingConnectorWorker:
logger.info(f"wait for metadata server to start, error: {e}")
time.sleep(1)
def bind_connector_metadata(
self, connector_metadata: CPUOffloadingConnectorMetadata) -> None:
def bind_connector_metadata(self, connector_metadata: CPUOffloadingConnectorMetadata) -> None:
for req_id, req in connector_metadata.requests.items():
if req_id in self.requests:
self.requests[req_id].update(req)
req = self.requests[req_id]
else:
self.requests[req_id] = req
for i in range(req.num_gpu_computed_tokens // self.block_size,
req.num_computed_tokens // self.block_size):
self.load_block_mapping.append(
(req.cpu_block_ids[i], req.gpu_block_ids[i]))
for i in range(req.num_gpu_computed_tokens // self.block_size, req.num_computed_tokens // self.block_size):
self.load_block_mapping.append((req.cpu_block_ids[i], req.gpu_block_ids[i]))
for req_id in connector_metadata.finished_req_ids:
if req_id in self.requests:
self.save_input_queue.put((req_id, self.requests[req_id]))
@@ -326,11 +292,11 @@ class CPUOffloadingConnectorWorker:
def register_kv_caches(self, kv_caches: dict[str, Sequence[torch.Tensor]]):
self.gpu_kv_caches = kv_caches
model_config = self.vllm_config.model_config
mla_config: Optional[MLAConfig] = None
mla_config: MLAConfig | None = None
if model_config.use_mla:
mla_config = MLAConfig(
model_config.hf_text_config.kv_lora_rank,
model_config.hf_text_config.qk_rope_head_dim)
model_config.hf_text_config.kv_lora_rank, model_config.hf_text_config.qk_rope_head_dim
)
self.cpu_kv_caches = list(
self.zmq_rpc_client.call(
"init_cpu_kv_caches",
@@ -338,7 +304,8 @@ class CPUOffloadingConnectorWorker:
self.tp_rank,
get_kv_cache_spec(self.vllm_config),
mla_config,
).values())
).values()
)
def start_load_kv(self) -> None:
self.current_layer = 0
@@ -358,10 +325,8 @@ class CPUOffloadingConnectorWorker:
cpu_kv_caches = self.cpu_kv_caches[layer]
with torch.npu.stream(self.load_stream):
for cpu_block_id, gpu_block_id in self.load_block_mapping:
for gpu_layer_part, cpu_layer_part in zip(
gpu_kv_caches, cpu_kv_caches):
gpu_layer_part[gpu_block_id].copy_(
cpu_layer_part[cpu_block_id], non_blocking=True)
for gpu_layer_part, cpu_layer_part in zip(gpu_kv_caches, cpu_kv_caches):
gpu_layer_part[gpu_block_id].copy_(cpu_layer_part[cpu_block_id], non_blocking=True)
def get_finished(self) -> set[str]:
done_sending: set[str] = set()
@@ -380,8 +345,7 @@ class CPUOffloadingConnectorWorker:
self.done_sending_count[req_id] += 1
other_ranks_finished_ids: list[str] = []
for i in range(1, self.tp_world_size):
other_ranks_finished_ids.extend(
self.tp_group.recv_object(src=i))
other_ranks_finished_ids.extend(self.tp_group.recv_object(src=i))
for req_id in other_ranks_finished_ids:
self.done_sending_count[req_id] += 1
all_done_sending: set[str] = set()
@@ -391,8 +355,7 @@ class CPUOffloadingConnectorWorker:
all_done_sending.add(req_id)
# release cpu_kv_cache after request sending finished
# to avoid rpc blocking, use thread to call rpc asynchronously
sending_finished_thread = threading.Thread(
target=self._sending_finished, args=(all_done_sending, ))
sending_finished_thread = threading.Thread(target=self._sending_finished, args=(all_done_sending,))
sending_finished_thread.daemon = True
sending_finished_thread.start()
@@ -411,11 +374,10 @@ class CPUOffloadingConnectorWorker:
while True:
req_id, req = self.save_input_queue.get()
for i in range(
req.num_cpu_computed_tokens // self.block_size,
min((req.num_computed_tokens + req.num_scheduled_tokens) //
self.block_size, len(req.cpu_block_ids))):
save_block_mapping.append(
(req.gpu_block_ids[i], req.cpu_block_ids[i]))
req.num_cpu_computed_tokens // self.block_size,
min((req.num_computed_tokens + req.num_scheduled_tokens) // self.block_size, len(req.cpu_block_ids)),
):
save_block_mapping.append((req.gpu_block_ids[i], req.cpu_block_ids[i]))
with torch.npu.stream(self.save_stream):
# MLA: kv_layer is tuple[tensor, tensor] means (rope, nope).
# non-MLA: kv_layer is list[tensor], typically means [k, v].
@@ -425,13 +387,9 @@ class CPUOffloadingConnectorWorker:
start, step = 0, 1
for i in range(start, len(save_block_mapping), step):
gpu_block_id, cpu_block_id = save_block_mapping[i]
for cpu_kv_caches, gpu_kv_caches in zip(
self.cpu_kv_caches, self.gpu_kv_caches.values()):
for cpu_layer_part, gpu_layer_part in zip(
cpu_kv_caches, gpu_kv_caches):
cpu_layer_part[cpu_block_id].copy_(
gpu_layer_part[gpu_block_id],
non_blocking=True)
for cpu_kv_caches, gpu_kv_caches in zip(self.cpu_kv_caches, self.gpu_kv_caches.values()):
for cpu_layer_part, gpu_layer_part in zip(cpu_kv_caches, gpu_kv_caches):
cpu_layer_part[cpu_block_id].copy_(gpu_layer_part[gpu_block_id], non_blocking=True)
self.save_stream.synchronize()
self.save_output_queue.put(req_id)
save_block_mapping.clear()
@@ -453,8 +411,7 @@ def get_kv_cache_spec(vllm_config: VllmConfig) -> dict[str, KVCacheSpec]:
if vllm_config.cache_config.cache_dtype == "auto":
kv_cache_dtype = vllm_config.model_config.dtype
else:
kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
vllm_config.cache_config.cache_dtype]
kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[vllm_config.cache_config.cache_dtype]
kv_cache_spec: dict[str, KVCacheSpec] = {}
attn_layers = get_layers_from_vllm_config(vllm_config, AttentionLayerBase)
@@ -472,10 +429,8 @@ def get_kv_cache_spec(vllm_config: VllmConfig) -> dict[str, KVCacheSpec]:
# using DSA. Fix the spec in vLLM is the final way.
block_size = vllm_config.cache_config.block_size
kv_cache_spec[layer_name] = FullAttentionSpec(
block_size=block_size,
num_kv_heads=1,
head_size=attn_module.head_size,
dtype=kv_cache_dtype)
block_size=block_size, num_kv_heads=1, head_size=attn_module.head_size, dtype=kv_cache_dtype
)
elif spec := attn_module.get_kv_cache_spec(vllm_config):
kv_cache_spec[layer_name] = spec
@@ -484,8 +439,7 @@ def get_kv_cache_spec(vllm_config: VllmConfig) -> dict[str, KVCacheSpec]:
if len(mamba_layers) > 0:
if vllm_config.cache_config.enable_prefix_caching:
raise NotImplementedError(
"Prefix caching is not supported for Mamba yet.")
raise NotImplementedError("Prefix caching is not supported for Mamba yet.")
for layer_name, mamba_module in mamba_layers.items():
if spec := mamba_module.get_kv_cache_spec(vllm_config):
kv_cache_spec[layer_name] = spec

View File

@@ -1,9 +1,10 @@
import math
import os
import pickle
from collections.abc import Callable
from dataclasses import dataclass
from multiprocessing.shared_memory import SharedMemory
from typing import Any, Callable, Optional
from typing import Any
import torch
import vllm.envs as envs
@@ -14,8 +15,7 @@ from vllm.utils.network_utils import make_zmq_socket
from vllm.utils.torch_utils import get_dtype_size
from vllm.v1.kv_cache_interface import AttentionSpec, MLAAttentionSpec
from vllm_ascend.distributed.kv_transfer.kv_pool.cpu_offload.cpu_kv_cache_manager import \
CPUKVCacheManager
from vllm_ascend.distributed.kv_transfer.kv_pool.cpu_offload.cpu_kv_cache_manager import CPUKVCacheManager
@dataclass
@@ -30,8 +30,7 @@ def get_cpu_offload_connector(vllm_config: VllmConfig) -> KVTransferConfig:
if kv_transfer_config.kv_connector == "CPUOffloadingConnector":
return kv_transfer_config
elif kv_transfer_config.kv_connector == "MultiConnector":
ktcs = kv_transfer_config.kv_connector_extra_config.get(
"connectors")
ktcs = kv_transfer_config.kv_connector_extra_config.get("connectors")
for ktc in ktcs:
kv_transfer_config = KVTransferConfig(**ktc)
if kv_transfer_config.kv_connector == "CPUOffloadingConnector":
@@ -44,7 +43,6 @@ class MetadataServer:
DEFAULT_CPU_SWAP_SPACE_GB = 800
class ZMQRPCClient:
def __init__(self, identity=None):
if identity is None:
identity = f"worker-{os.getpid()}-{id(self)}"
@@ -56,7 +54,8 @@ class MetadataServer:
zmq.DEALER, # type: ignore
bind=False,
identity=identity.encode(),
linger=0)
linger=0,
)
def call(self, func_name: str, *args, **kwargs) -> Any:
request = (func_name, args, kwargs)
@@ -74,11 +73,9 @@ class MetadataServer:
self.shared_memory_dict = memory_dict
result = {}
for key, shm in memory_dict.items():
tensor = torch.frombuffer(
shm.buf, dtype=layer_dtype).reshape(layer_size)
tensor = torch.frombuffer(shm.buf, dtype=layer_dtype).reshape(layer_size)
if mla_config is not None:
tensor = tensor.split(
[mla_config.nope_dim, mla_config.rope_dim], dim=-1)
tensor = tensor.split([mla_config.nope_dim, mla_config.rope_dim], dim=-1)
result[key] = tensor
return result
@@ -86,7 +83,7 @@ class MetadataServer:
# will be finalized by outer process
self.socket.close()
self.ctx.term()
if hasattr(self, 'shared_memory_dict'):
if hasattr(self, "shared_memory_dict"):
for shm in self.shared_memory_dict.values():
shm.close()
@@ -96,7 +93,8 @@ class MetadataServer:
kv_transfer_config = get_cpu_offload_connector(vllm_config)
assert kv_transfer_config is not None
available_memory_gb = kv_transfer_config.get_from_extra_config(
"cpu_swap_space_gb", MetadataServer.DEFAULT_CPU_SWAP_SPACE_GB)
"cpu_swap_space_gb", MetadataServer.DEFAULT_CPU_SWAP_SPACE_GB
)
self.available_memory = available_memory_gb * 1024 * 1024 * 1024
logger.info(f"cpu swap space: {self.available_memory} bytes")
self.ctx = zmq.Context() # type: ignore
@@ -105,7 +103,8 @@ class MetadataServer:
MetadataServer.METADATA_SERVER_ADDRESS,
zmq.ROUTER, # type: ignore
bind=True,
linger=0)
linger=0,
)
self.functions: dict[str, Callable] = {
"init_cpu_kv_caches": self.init_cpu_kv_caches,
"post_init": self.post_init,
@@ -133,15 +132,11 @@ class MetadataServer:
tp_rank: int,
kv_cache_specs: dict[str, AttentionSpec],
mla_config: MLAConfig,
) -> tuple[dict[str, SharedMemory], tuple[int, ...], torch.dtype,
MLAConfig]:
) -> tuple[dict[str, SharedMemory], tuple[int, ...], torch.dtype, MLAConfig]:
logger.info(f"receive pp rank: {pp_rank}, tp rank: {tp_rank}")
# follow the assumption that each layer has the same spec
layer = next(iter(kv_cache_specs.values()))
assert all([
layer.page_size_bytes == any.page_size_bytes
for any in kv_cache_specs.values()
])
assert all([layer.page_size_bytes == any.page_size_bytes for any in kv_cache_specs.values()])
use_mla = isinstance(layer, MLAAttentionSpec)
# mla shares the same kv cache among different tp
if use_mla:
@@ -154,30 +149,24 @@ class MetadataServer:
available_memory //= self.pipeline_parallel_size
available_memory //= len(kv_cache_specs)
num_blocks = available_memory // layer.page_size_bytes
layer_size = (num_blocks, layer.block_size, layer.num_kv_heads,
layer.head_size) # type: ignore
layer_size = (num_blocks, layer.block_size, layer.num_kv_heads, layer.head_size) # type: ignore
else:
available_memory //= self.world_size
available_memory //= len(kv_cache_specs)
num_blocks = available_memory // layer.page_size_bytes
layer_size = (2, num_blocks, layer.block_size, layer.num_kv_heads,
layer.head_size) # type: ignore
layer_size = (2, num_blocks, layer.block_size, layer.num_kv_heads, layer.head_size) # type: ignore
nbytes = math.prod(layer_size) * get_dtype_size(layer.dtype)
for layer_name in kv_cache_specs.keys():
for layer_name in kv_cache_specs:
# only this format can share during ZeroMQ+pickle
shared_memory_dict[
layer_name] = MetadataServer._safe_create_shared_memory(
f"cpu_kv_cache_{pp_rank}_{tp_rank}_{layer_name}", nbytes)
shared_memory_dict[layer_name] = MetadataServer._safe_create_shared_memory(
f"cpu_kv_cache_{pp_rank}_{tp_rank}_{layer_name}", nbytes
)
if use_mla:
assert mla_config is not None
assert layer.head_size == mla_config.rope_dim + mla_config.nope_dim
self.shared_memory[(pp_rank,
tp_rank)] = (shared_memory_dict, layer_size,
layer.dtype, mla_config)
self.shared_memory[(pp_rank, tp_rank)] = (shared_memory_dict, layer_size, layer.dtype, mla_config)
else:
self.shared_memory[(pp_rank,
tp_rank)] = (shared_memory_dict, layer_size,
layer.dtype, None)
self.shared_memory[(pp_rank, tp_rank)] = (shared_memory_dict, layer_size, layer.dtype, None)
if self.num_cpu_blocks == -1 or num_blocks < self.num_cpu_blocks:
self.num_cpu_blocks = num_blocks
self.layer = layer
@@ -185,23 +174,20 @@ class MetadataServer:
def post_init(self):
# different processors in data parallel may call multiple times
if hasattr(self, 'cpu_block_manager'):
if hasattr(self, "cpu_block_manager"):
return
# do shared_memory() at least once
logger.info(f"assign cpu num blocks: {self.num_cpu_blocks}")
assert self.num_cpu_blocks >= 0
self.cpu_block_manager = CPUKVCacheManager(self.layer,
self.num_cpu_blocks)
self.functions.update({
"get_matched_num_and_touch":
self.cpu_block_manager.get_matched_num_and_touch,
"allocate_slots":
self.cpu_block_manager.allocate_slots,
"record_request_cache_and_free_slots":
self.cpu_block_manager.record_request_cache_and_free_slots,
"cache_and_free_slots":
self.cpu_block_manager.cache_and_free_slots,
})
self.cpu_block_manager = CPUKVCacheManager(self.layer, self.num_cpu_blocks)
self.functions.update(
{
"get_matched_num_and_touch": self.cpu_block_manager.get_matched_num_and_touch,
"allocate_slots": self.cpu_block_manager.allocate_slots,
"record_request_cache_and_free_slots": self.cpu_block_manager.record_request_cache_and_free_slots,
"cache_and_free_slots": self.cpu_block_manager.cache_and_free_slots,
}
)
def serve_step(self):
client_id = self.socket.recv()
@@ -228,8 +214,7 @@ class MetadataServer:
def shutdown(self):
self.socket.close()
self.ctx.term()
socket_path = MetadataServer.METADATA_SERVER_ADDRESS.replace(
"ipc://", "")
socket_path = MetadataServer.METADATA_SERVER_ADDRESS.replace("ipc://", "")
if os.path.exists(socket_path):
os.remove(socket_path)
for cached in self.shared_memory.values():
@@ -239,11 +224,9 @@ class MetadataServer:
class MetadataServerProc:
@staticmethod
def run_metadata_server(vllm_config: VllmConfig):
if (not vllm_config.cache_config.enable_prefix_caching
or get_cpu_offload_connector(vllm_config) is None):
if not vllm_config.cache_config.enable_prefix_caching or get_cpu_offload_connector(vllm_config) is None:
return
shutdown_requested = False
@@ -257,7 +240,7 @@ class MetadataServerProc:
# Either SIGTERM or SIGINT will terminate the worker
# signal.signal(signal.SIGTERM, _signal_handler)
# signal.signal(signal.SIGINT, _signal_handler)
metadata_server: Optional[MetadataServer] = None
metadata_server: MetadataServer | None = None
try:
metadata_server = MetadataServer(vllm_config)
logger.info("Metadata server started.")

View File

@@ -4,19 +4,21 @@ from typing import TYPE_CHECKING, Any, Optional
import torch
from ucm.integration.vllm.ucm_connector import UCMConnector
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole
from vllm.logger import init_logger
from vllm.v1.core.sched.output import SchedulerOutput
logger = init_logger(__name__)
# isort: off
if TYPE_CHECKING:
from vllm.v1.attention.backend import AttentionMetadata # type: ignore
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
KVConnectorPromMetrics, KVConnectorStats, PromMetric, PromMetricT)
KVConnectorPromMetrics,
KVConnectorStats,
PromMetric,
PromMetricT,
)
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.kv_cache_interface import KVCacheConfig
@@ -25,16 +27,13 @@ if TYPE_CHECKING:
class UCMConnectorV1(KVConnectorBase_V1):
def __init__(
self,
vllm_config: "VllmConfig",
role: KVConnectorRole,
kv_cache_config: "KVCacheConfig",
):
super().__init__(vllm_config=vllm_config,
role=role,
kv_cache_config=kv_cache_config)
super().__init__(vllm_config=vllm_config, role=role, kv_cache_config=kv_cache_config)
assert vllm_config.kv_transfer_config is not None
ImplCls = UCMConnector
@@ -60,8 +59,7 @@ class UCMConnectorV1(KVConnectorBase_V1):
"""
self._ucm_engine.register_kv_caches(kv_caches)
def start_load_kv(self, forward_context: "ForwardContext",
**kwargs: Any) -> None:
def start_load_kv(self, forward_context: "ForwardContext", **kwargs: Any) -> None:
"""
Start loading the KV cache from the connector to vLLM's paged
KV buffer. This is called from the forward context before the
@@ -110,8 +108,7 @@ class UCMConnectorV1(KVConnectorBase_V1):
attn_metadata (AttentionMetadata): the attention metadata.
**kwargs: additional arguments for the save operation.
"""
self._ucm_engine.save_kv_layer(layer_name, kv_layer, attn_metadata,
**kwargs)
self._ucm_engine.save_kv_layer(layer_name, kv_layer, attn_metadata, **kwargs)
def wait_for_save(self) -> None:
"""
@@ -131,8 +128,7 @@ class UCMConnectorV1(KVConnectorBase_V1):
"""
self._ucm_engine.clear_connector_metadata()
def bind_connector_metadata(
self, connector_metadata: KVConnectorMetadata) -> None:
def bind_connector_metadata(self, connector_metadata: KVConnectorMetadata) -> None:
"""Set the connector metadata from the scheduler.
This function should be called by the model runner every time
@@ -175,20 +171,15 @@ class UCMConnectorV1(KVConnectorBase_V1):
the number of tokens that can be loaded from the
external KV cache beyond what is already computed.
"""
return self._ucm_engine.get_num_new_matched_tokens(
request, num_computed_tokens)
return self._ucm_engine.get_num_new_matched_tokens(request, num_computed_tokens)
def update_state_after_alloc(self, request: "Request",
blocks: "KVCacheBlocks",
num_external_tokens: int) -> None:
def update_state_after_alloc(self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int) -> None:
"""
Update KVConnector state after block allocation.
"""
self._ucm_engine.update_state_after_alloc(request, blocks,
num_external_tokens)
self._ucm_engine.update_state_after_alloc(request, blocks, num_external_tokens)
def build_connector_meta(
self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
def build_connector_meta(self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
"""
Build the connector metadata for this step.
@@ -222,10 +213,7 @@ class UCMConnectorV1(KVConnectorBase_V1):
# ==============================
@classmethod
def build_kv_connector_stats(
cls,
data: dict[str, Any] | None = None
) -> Optional["KVConnectorStats"]:
def build_kv_connector_stats(cls, data: dict[str, Any] | None = None) -> Optional["KVConnectorStats"]:
"""
KVConnectorStats resolution method. This method allows dynamically
registered connectors to return their own KVConnectorStats object,

View File

@@ -1,19 +1,16 @@
import ipaddress
import threading
from typing import Optional
from mooncake.engine import TransferEngine # type: ignore
class GlobalTE():
class GlobalTE:
def __init__(self):
self.transfer_engine = None
self.is_register_buffer: bool = False
self.transfer_engine_lock = threading.Lock()
self.register_buffer_lock = threading.Lock()
def get_transfer_engine(self, hostname: str, device_name: Optional[str]):
def get_transfer_engine(self, hostname: str, device_name: str | None):
if self.transfer_engine is None:
with self.transfer_engine_lock:
# Double-Checked Locking
@@ -22,12 +19,9 @@ class GlobalTE():
raise RuntimeError("mooncake is not available")
self.transfer_engine = TransferEngine()
device_name = device_name if device_name is not None else ""
ret_value = self.transfer_engine.initialize(
hostname, "P2PHANDSHAKE", "ascend", device_name)
ret_value = self.transfer_engine.initialize(hostname, "P2PHANDSHAKE", "ascend", device_name)
if ret_value != 0:
raise RuntimeError(
f"TransferEngine initialization failed with ret_value: {ret_value}"
)
raise RuntimeError(f"TransferEngine initialization failed with ret_value: {ret_value}")
return self.transfer_engine
def register_buffer(self, ptrs: list[int], sizes: list[int]):

View File

@@ -6,8 +6,7 @@ import torch.distributed as dist
from vllm_ascend.distributed.parallel_state import get_p_tp_group
def kv_alltoall_and_rearrange(pd_tp_ratio: int, key: torch.Tensor,
value: torch.TensorType):
def kv_alltoall_and_rearrange(pd_tp_ratio: int, key: torch.Tensor, value: torch.TensorType):
if pd_tp_ratio <= 1:
return None, None
elif key is None or value is None:
@@ -20,22 +19,17 @@ def kv_alltoall_and_rearrange(pd_tp_ratio: int, key: torch.Tensor,
def alltoall_and_rearrange(tp_ratio: int, input_tensor: torch.Tensor):
num_kv_heads = input_tensor.size(1)
output_tensor = torch.zeros_like(input_tensor)
dist.all_to_all_single(output_tensor,
input_tensor,
group=get_p_tp_group().device_group)
dist.all_to_all_single(output_tensor, input_tensor, group=get_p_tp_group().device_group)
input_tensor = 0
result = rearrange_output(output_tensor, tp_ratio, num_kv_heads)
output_tensor = 0
return result
def rearrange_output(base_output: torch.Tensor, cut_num: int,
num_kv_heads: int):
def rearrange_output(base_output: torch.Tensor, cut_num: int, num_kv_heads: int):
size_0 = base_output.size(0)
if size_0 % cut_num != 0:
raise ValueError(
f"The size of dim 0 [{size_0}] must be divisible by the cut_num [{cut_num}]"
)
raise ValueError(f"The size of dim 0 [{size_0}] must be divisible by the cut_num [{cut_num}]")
chunk_size = size_0 // cut_num
reshaped = base_output.view(cut_num, chunk_size, -1)
transposed = reshaped.transpose(0, 1)
@@ -46,16 +40,13 @@ def align_memory(tensor: torch.Tensor, alignment: int) -> torch.Tensor:
data_ptr = tensor.data_ptr()
aligned_addr = (data_ptr + alignment - 1) // alignment * alignment
offset = (aligned_addr - data_ptr) // tensor.element_size()
return tensor[int(offset):]
return tensor[int(offset) :]
def get_transfer_timeout_value():
ascend_transfer_timeout = os.getenv("ASCEND_TRANSFER_TIMEOUT", "")
if len(ascend_transfer_timeout) > 0:
return int(ascend_transfer_timeout)
hccl_rdma_timeout = int(os.getenv('HCCL_RDMA_TIMEOUT',
'20')) # type: ignore
hccl_rdma_retry_cnt = int(os.getenv('HCCL_RDMA_RETRY_CNT',
'7')) # type: ignore
return int((4.096 * (2**hccl_rdma_timeout)) * hccl_rdma_retry_cnt // 1000 +
3000)
hccl_rdma_timeout = int(os.getenv("HCCL_RDMA_TIMEOUT", "20")) # type: ignore
hccl_rdma_retry_cnt = int(os.getenv("HCCL_RDMA_RETRY_CNT", "7")) # type: ignore
return int((4.096 * (2**hccl_rdma_timeout)) * hccl_rdma_retry_cnt // 1000 + 3000)

View File

@@ -4,8 +4,7 @@ from vllm.logger import init_logger
from vllm.utils.platform_utils import is_pin_memory_available
from vllm.v1.attention.backend import AttentionBackend # type: ignore
from vllm.v1.kv_offload.mediums import CPULoadStoreSpec, GPULoadStoreSpec
from vllm.v1.kv_offload.worker.worker import (OffloadingHandler,
TransferResult, TransferSpec)
from vllm.v1.kv_offload.worker.worker import OffloadingHandler, TransferResult, TransferSpec
logger = init_logger(__name__)
@@ -44,7 +43,6 @@ def expand_block_ids(
class CpuNpuOffloadingHandler(OffloadingHandler):
def __init__(
self,
gpu_block_size: int,
@@ -81,20 +79,22 @@ class CpuNpuOffloadingHandler(OffloadingHandler):
cpu_shape[num_blocks_idx] = num_cpu_blocks * self.block_size_factor
logger.debug("Allocating CPU tensor of shape %r", cpu_shape)
self.cpu_tensors.append((
torch.zeros(
cpu_shape,
dtype=gpu_tensor[0].dtype,
device="cpu",
pin_memory=pin_memory,
),
torch.zeros(
cpu_shape,
dtype=gpu_tensor[0].dtype,
device="cpu",
pin_memory=pin_memory,
),
))
self.cpu_tensors.append(
(
torch.zeros(
cpu_shape,
dtype=gpu_tensor[0].dtype,
device="cpu",
pin_memory=pin_memory,
),
torch.zeros(
cpu_shape,
dtype=gpu_tensor[0].dtype,
device="cpu",
pin_memory=pin_memory,
),
)
)
def transfer_async(self, job_id: int, spec: TransferSpec) -> bool:
logger.info("start transfer_async...")
@@ -123,9 +123,7 @@ class CpuNpuOffloadingHandler(OffloadingHandler):
dst_sub_blocks_to_skip = -src_blocks.size % dst_block_size_factor
src_sub_block_count = src_blocks.size * src_block_size_factor
assert (
src_sub_block_count == dst_blocks.size * dst_block_size_factor -
dst_sub_blocks_to_skip)
assert src_sub_block_count == dst_blocks.size * dst_block_size_factor - dst_sub_blocks_to_skip
src_to_dst = np.empty((src_sub_block_count, 2), dtype=np.int64)
expand_block_ids(src_blocks, src_block_size_factor, src_to_dst[:, 0])
@@ -137,18 +135,14 @@ class CpuNpuOffloadingHandler(OffloadingHandler):
)
src_to_dst_tensor = torch.from_numpy(src_to_dst)
event = self.events_pool.pop(
) if self.events_pool else torch.npu.Event()
event = self.events_pool.pop() if self.events_pool else torch.npu.Event()
with torch.npu.stream(stream):
for src_tensor, dst_tensor in zip(src_tensors, dst_tensors):
src_key_cache, src_value_cache = src_tensor[0], src_tensor[1]
dst_key_cache, dst_value_cache = dst_tensor[0], dst_tensor[1]
torch.ops._C_ascend.swap_blocks(src_key_cache, dst_key_cache,
src_to_dst_tensor)
torch.ops._C_ascend.swap_blocks(src_value_cache,
dst_value_cache,
src_to_dst_tensor)
torch.ops._C_ascend.swap_blocks(src_key_cache, dst_key_cache, src_to_dst_tensor)
torch.ops._C_ascend.swap_blocks(src_value_cache, dst_value_cache, src_to_dst_tensor)
event.record(stream)
@@ -175,4 +169,4 @@ class CpuNpuOffloadingHandler(OffloadingHandler):
event = self.transfer_events.get(job_id)
if event is not None:
# This will block until the NPU event is complete
event.synchronize()
event.synchronize()

View File

@@ -1,48 +1,40 @@
from collections.abc import Iterator
from typing import Optional
import torch
from vllm.config import VllmConfig
from vllm.v1.attention.backend import AttentionBackend # type: ignore
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.kv_offload.abstract import LoadStoreSpec, OffloadingManager
from vllm.v1.kv_offload.backends.cpu import CPUBackend
from vllm.v1.kv_offload.lru_manager import LRUOffloadingManager
from vllm.v1.kv_offload.mediums import CPULoadStoreSpec, GPULoadStoreSpec
from vllm.v1.kv_offload.spec import OffloadingSpec
from vllm.v1.kv_offload.worker.worker import OffloadingHandler
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm_ascend.kv_offload.cpu_npu import CpuNpuOffloadingHandler
class NPUOffloadingSpec(OffloadingSpec):
def __init__(self,
vllm_config: VllmConfig,
kv_cache_config: Optional[KVCacheConfig] = None):
def __init__(self, vllm_config: VllmConfig, kv_cache_config: KVCacheConfig | None = None):
super().__init__(vllm_config, kv_cache_config)
num_cpu_blocks = self.extra_config.get("num_cpu_blocks")
if not num_cpu_blocks:
raise Exception(
"num_cpu_blocks must be specified in kv_connector_extra_config"
)
raise Exception("num_cpu_blocks must be specified in kv_connector_extra_config")
self.num_cpu_blocks: int = num_cpu_blocks
# scheduler-side
self._manager: Optional[OffloadingManager] = None
self._manager: OffloadingManager | None = None
# worker-side
self._handler: Optional[OffloadingHandler] = None
self._handler: OffloadingHandler | None = None
def get_manager(self) -> OffloadingManager:
if not self._manager:
kv_events_config = self.vllm_config.kv_events_config
enable_events = (kv_events_config is not None
and kv_events_config.enable_kv_cache_events)
enable_events = kv_events_config is not None and kv_events_config.enable_kv_cache_events
self._manager = LRUOffloadingManager(
CPUBackend(block_size=self.offloaded_block_size,
num_blocks=self.num_cpu_blocks),
CPUBackend(block_size=self.offloaded_block_size, num_blocks=self.num_cpu_blocks),
enable_events=enable_events,
)
return self._manager
@@ -51,8 +43,7 @@ class NPUOffloadingSpec(OffloadingSpec):
self,
kv_caches: dict[str, torch.Tensor],
attn_backends: dict[str, type[AttentionBackend]],
) -> Iterator[tuple[type[LoadStoreSpec], type[LoadStoreSpec],
OffloadingHandler]]:
) -> Iterator[tuple[type[LoadStoreSpec], type[LoadStoreSpec], OffloadingHandler]]:
if not self._handler:
self._handler = CpuNpuOffloadingHandler(
attn_backends=attn_backends,

View File

@@ -16,11 +16,13 @@
import torch
def bgmv_shrink(inputs: torch.Tensor,
lora_a_weights: torch.Tensor,
output_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
scaling: float = 1.0):
def bgmv_shrink(
inputs: torch.Tensor,
lora_a_weights: torch.Tensor,
output_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
scaling: float = 1.0,
):
return torch.ops._C_ascend.bgmv_shrink(
inputs,
lora_a_weights,
@@ -30,11 +32,13 @@ def bgmv_shrink(inputs: torch.Tensor,
)
def bgmv_expand(inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
add_inputs: bool = True):
def bgmv_expand(
inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
add_inputs: bool = True,
):
return torch.ops._C_ascend.bgmv_expand(
inputs,
lora_b_weights,
@@ -45,16 +49,18 @@ def bgmv_expand(inputs: torch.Tensor,
)
def bgmv_expand_slice(inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
slice_offset: int,
slice_size: int,
add_inputs: bool = True):
return torch.ops._C_ascend.bgmv_expand(inputs, lora_b_weights,
lora_indices_tensor, output_tensor,
slice_offset, slice_size)
def bgmv_expand_slice(
inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
slice_offset: int,
slice_size: int,
add_inputs: bool = True,
):
return torch.ops._C_ascend.bgmv_expand(
inputs, lora_b_weights, lora_indices_tensor, output_tensor, slice_offset, slice_size
)
def sgmv_shrink(
@@ -69,21 +75,23 @@ def sgmv_shrink(
token_nums: int,
scaling: float,
):
return torch.ops._C_ascend.sgmv_shrink(inputs, lora_a_weights,
lora_indices_tensor, seq_len_tensor,
output_tensor, scaling)
return torch.ops._C_ascend.sgmv_shrink(
inputs, lora_a_weights, lora_indices_tensor, seq_len_tensor, output_tensor, scaling
)
def sgmv_expand(inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
b_seq_start_loc: torch.Tensor,
seq_len_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
batches: int,
max_seq_length: int,
token_nums: int,
add_inputs: bool = False):
def sgmv_expand(
inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
b_seq_start_loc: torch.Tensor,
seq_len_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
batches: int,
max_seq_length: int,
token_nums: int,
add_inputs: bool = False,
):
return torch.ops._C_ascend.sgmv_expand(
inputs,
lora_b_weights,
@@ -95,19 +103,20 @@ def sgmv_expand(inputs: torch.Tensor,
)
def sgmv_expand_slice(inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
b_seq_start_loc: torch.Tensor,
seq_len_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
batches: int,
max_seq_length: int,
token_nums: int,
slice_offset: int,
slice_size: int,
add_inputs: bool = False):
return torch.ops._C_ascend.sgmv_expand(inputs, lora_b_weights,
lora_indices_tensor, seq_len_tensor,
output_tensor, slice_offset,
slice_size)
def sgmv_expand_slice(
inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
b_seq_start_loc: torch.Tensor,
seq_len_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
batches: int,
max_seq_length: int,
token_nums: int,
slice_offset: int,
slice_size: int,
add_inputs: bool = False,
):
return torch.ops._C_ascend.sgmv_expand(
inputs, lora_b_weights, lora_indices_tensor, seq_len_tensor, output_tensor, slice_offset, slice_size
)

View File

@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Callable, Optional, Tuple, Union
from collections.abc import Callable
import torch
from vllm.lora.punica_wrapper.punica_base import PunicaWrapperBase
@@ -18,26 +18,30 @@ class PunicaWrapperNPU(PunicaWrapperBase):
Multi-LoRA, and to provide the interface for the pytorch punica ops.
"""
def __init__(self, max_num_batched_tokens: int, max_batches: int,
device: Union[torch.device, str], **kwargs):
PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches,
device)
def __init__(self, max_num_batched_tokens: int, max_batches: int, device: torch.device | str, **kwargs):
PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, device)
refresh_all_lora_classes()
self.lora_config = kwargs.get("lora_config")
if get_ascend_device_type() == AscendDeviceType._310P or (
self.lora_config is not None
and self.lora_config.max_lora_rank >= 128):
from vllm.lora.ops.torch_ops import (bgmv_expand,
bgmv_expand_slice,
bgmv_shrink, sgmv_expand,
sgmv_expand_slice,
sgmv_shrink)
self.lora_config is not None and self.lora_config.max_lora_rank >= 128
):
from vllm.lora.ops.torch_ops import (
bgmv_expand,
bgmv_expand_slice,
bgmv_shrink,
sgmv_expand,
sgmv_expand_slice,
sgmv_shrink,
)
else:
from vllm_ascend.lora.lora_ops import (bgmv_expand,
bgmv_expand_slice,
bgmv_shrink, sgmv_expand,
sgmv_expand_slice,
sgmv_shrink)
from vllm_ascend.lora.lora_ops import (
bgmv_expand,
bgmv_expand_slice,
bgmv_shrink,
sgmv_expand,
sgmv_expand_slice,
sgmv_shrink,
)
self.bgmv_expand = bgmv_expand
self.bgmv_expand_slice = bgmv_expand_slice
self.bgmv_shrink = bgmv_shrink
@@ -52,7 +56,7 @@ class PunicaWrapperNPU(PunicaWrapperBase):
w_t_all: torch.Tensor,
scale: float,
):
#No LoRA request, so return directly
# No LoRA request, so return directly
if self.no_lora:
return
self.sgmv_shrink(
@@ -79,7 +83,7 @@ class PunicaWrapperNPU(PunicaWrapperBase):
w_t_all: torch.Tensor,
add_inputs: bool,
):
#No LoRA request, so return directly
# No LoRA request, so return directly
if self.no_lora:
return
self.sgmv_expand(
@@ -108,7 +112,7 @@ class PunicaWrapperNPU(PunicaWrapperBase):
y_slice_size: int,
add_inputs: bool,
):
#No LoRA request, so return directly
# No LoRA request, so return directly
if self.no_lora:
return
self.sgmv_expand_slice(
@@ -130,8 +134,7 @@ class PunicaWrapperNPU(PunicaWrapperBase):
y_slice_size: int,
add_inputs: bool,
):
self.bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices,
y_offset, y_slice_size, add_inputs)
self.bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset, y_slice_size, add_inputs)
def _apply_expand(
self,
@@ -148,13 +151,10 @@ class PunicaWrapperNPU(PunicaWrapperBase):
GEMM of lora'b.
"""
expand_slice_fun: Callable = (self._expand_slice_prefill
if self.is_prefill else
self._expand_slice_decode)
expand_slice_fun: Callable = self._expand_slice_prefill if self.is_prefill else self._expand_slice_decode
expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_inputs)
def _apply_shrink(self, y: torch.Tensor, x: torch.Tensor,
w_t_all: torch.Tensor, scale: float):
def _apply_shrink(self, y: torch.Tensor, x: torch.Tensor, w_t_all: torch.Tensor, scale: float):
"""
Perform the ` y+=x@w_t_all` computation, which is suitable for the
GEMM of lora'a.
@@ -165,14 +165,18 @@ class PunicaWrapperNPU(PunicaWrapperBase):
"""
y_org = y
y = y.view(-1, y.shape[-1])
shrink_fun: Callable = (self._shrink_prefill
if self.is_prefill else self._shrink_decode)
shrink_fun: Callable = self._shrink_prefill if self.is_prefill else self._shrink_decode
shrink_fun(y, x, w_t_all, scale)
y = y.view_as(y_org)
def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor],
x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...],
scale: float, **kwargs):
def add_shrink(
self,
y: tuple[torch.Tensor, ...] | torch.Tensor,
x: torch.Tensor,
lora_a_stacked: tuple[torch.Tensor, ...],
scale: float,
**kwargs,
):
"""
Performs GEMM for multiple slices of lora_a.
When `is_prefill is` true, it indicates that it is currently the
@@ -194,18 +198,19 @@ class PunicaWrapperNPU(PunicaWrapperBase):
x = x.view(-1, x.shape[-1])
# TODO fuse these kernels
for slice_idx in range(len(lora_a_stacked)):
self._apply_shrink(y[slice_idx], x, lora_a_stacked[slice_idx],
scale)
self._apply_shrink(y[slice_idx], x, lora_a_stacked[slice_idx], scale)
def add_expand(self,
y: torch.Tensor,
x: Union[Tuple[torch.Tensor, ...], torch.Tensor],
lora_b_stacked: Tuple[torch.Tensor, ...],
lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
output_slices: Tuple[int, ...],
offset_start: int = 0,
add_inputs=True,
**kwargs) -> None:
def add_expand(
self,
y: torch.Tensor,
x: tuple[torch.Tensor, ...] | torch.Tensor,
lora_b_stacked: tuple[torch.Tensor, ...],
lora_bias_stacked: tuple[torch.Tensor, ...] | None,
output_slices: tuple[int, ...],
offset_start: int = 0,
add_inputs=True,
**kwargs,
) -> None:
"""
Performs GEMM and bias addition for multiple slices of lora_b.
@@ -229,8 +234,7 @@ class PunicaWrapperNPU(PunicaWrapperBase):
y = y.view(-1, y.shape[-1])
offset_left = offset_start
if lora_bias_stacked is not None:
self._apply_bias(self.token_lora_indices, y, output_slices,
lora_bias_stacked)
self._apply_bias(self.token_lora_indices, y, output_slices, lora_bias_stacked)
for slice_idx in range(len(lora_b_stacked)):
self._apply_expand(
y,
@@ -243,12 +247,9 @@ class PunicaWrapperNPU(PunicaWrapperBase):
offset_left += output_slices[slice_idx]
y = y.view_as(y_org)
def add_lora_embedding(self,
y: torch.Tensor,
x: torch.Tensor,
lora_b_stacked: torch.Tensor,
add_inputs: bool = True,
**kwargs) -> None:
def add_lora_embedding(
self, y: torch.Tensor, x: torch.Tensor, lora_b_stacked: torch.Tensor, add_inputs: bool = True, **kwargs
) -> None:
"""
Applies lora specifically for VocabParallelEmbeddingWithLoRA.
@@ -263,21 +264,22 @@ class PunicaWrapperNPU(PunicaWrapperBase):
"""
# Embedding layer only need expand op
expand_fun: Callable = (self._expand_prefill
if self.is_prefill else self._expand_decode)
expand_fun: Callable = self._expand_prefill if self.is_prefill else self._expand_decode
x = x.to(torch.float32)
expand_fun(y, x, lora_b_stacked, add_inputs)
def add_lora_linear(self,
y: torch.Tensor,
x: torch.Tensor,
lora_a_stacked: Tuple[torch.Tensor, ...],
lora_b_stacked: Tuple[torch.Tensor, ...],
scale: float,
output_slices: Tuple[int, ...],
*,
buffer: Optional[Tuple[torch.Tensor, ...]] = None,
**kwargs) -> None:
def add_lora_linear(
self,
y: torch.Tensor,
x: torch.Tensor,
lora_a_stacked: tuple[torch.Tensor, ...],
lora_b_stacked: tuple[torch.Tensor, ...],
scale: float,
output_slices: tuple[int, ...],
*,
buffer: tuple[torch.Tensor, ...] | None = None,
**kwargs,
) -> None:
"""
Applicable to linear-related lora.
@@ -308,27 +310,22 @@ class PunicaWrapperNPU(PunicaWrapperBase):
# We set the buffer to be float32 by default, consistent with the
# triton op
buffer = tuple(
torch.zeros(
(x.size(0), r), dtype=torch.float32, device=x.device)
for _ in range(len(output_slices)))
torch.zeros((x.size(0), r), dtype=torch.float32, device=x.device) for _ in range(len(output_slices))
)
self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs)
self.add_expand(y,
buffer,
lora_b_stacked,
None,
output_slices,
add_inputs=True,
**kwargs)
self.add_expand(y, buffer, lora_b_stacked, None, output_slices, add_inputs=True, **kwargs)
def add_lora_logits(self,
y: torch.Tensor,
x: torch.Tensor,
lora_a_stacked: torch.Tensor,
lora_b_stacked: torch.Tensor,
scale,
*,
buffer: Optional[torch.Tensor] = None,
**kwargs) -> None:
def add_lora_logits(
self,
y: torch.Tensor,
x: torch.Tensor,
lora_a_stacked: torch.Tensor,
lora_b_stacked: torch.Tensor,
scale,
*,
buffer: torch.Tensor | None = None,
**kwargs,
) -> None:
"""
Applies lora specifically for LogitsProcessorWithLoRA.
@@ -350,9 +347,7 @@ class PunicaWrapperNPU(PunicaWrapperBase):
r = lora_b_stacked.size(-1)
if buffer is None:
buffer = torch.zeros((x.size(0), r),
dtype=torch.float32,
device=x.device)
buffer = torch.zeros((x.size(0), r), dtype=torch.float32, device=x.device)
indices = self.sampler_indices

View File

@@ -1,91 +1,75 @@
from typing import Optional
import vllm
from torch import nn
from transformers import PretrainedConfig
from vllm.config import LoRAConfig
from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
MergedColumnParallelLinearWithLoRA,
MergedQKVParallelLinearWithLoRA,
QKVParallelLinearWithLoRA,
RowParallelLinearWithLoRA,
VocabParallelEmbeddingWithLoRA)
from vllm.lora.layers import (
ColumnParallelLinearWithLoRA,
MergedColumnParallelLinearWithLoRA,
MergedQKVParallelLinearWithLoRA,
QKVParallelLinearWithLoRA,
RowParallelLinearWithLoRA,
VocabParallelEmbeddingWithLoRA,
)
from vllm.lora.layers.utils import _not_fully_sharded_can_replace
from vllm_ascend.ops.linear import (AscendColumnParallelLinear,
AscendMergedColumnParallelLinear,
AscendQKVParallelLinear,
AscendRowParallelLinear)
from vllm_ascend.ops.vocab_parallel_embedding import \
AscendVocabParallelEmbedding
from vllm_ascend.ops.linear import (
AscendColumnParallelLinear,
AscendMergedColumnParallelLinear,
AscendQKVParallelLinear,
AscendRowParallelLinear,
)
from vllm_ascend.ops.vocab_parallel_embedding import AscendVocabParallelEmbedding
class AscendColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
@classmethod
def can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: Optional[PretrainedConfig],
model_config: PretrainedConfig | None,
) -> bool:
return type(source_layer) is AscendColumnParallelLinear
class AscendMergedColumnParallelLinearWithLoRA(
MergedColumnParallelLinearWithLoRA):
class AscendMergedColumnParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA):
@classmethod
def can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: Optional[PretrainedConfig],
model_config: PretrainedConfig | None,
) -> bool:
return type(source_layer) is AscendMergedColumnParallelLinear
class AscendRowParallelLinearWithLoRA(RowParallelLinearWithLoRA):
@classmethod
def can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: Optional[PretrainedConfig],
model_config: PretrainedConfig | None,
) -> bool:
return type(source_layer) is AscendRowParallelLinear
class AscendVocabParallelEmbeddingWithLoRA(VocabParallelEmbeddingWithLoRA):
@classmethod
def can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: Optional[PretrainedConfig],
model_config: PretrainedConfig | None,
) -> bool:
return type(source_layer) is AscendVocabParallelEmbedding
class AscendQKVParallelLinearWithLoRA(QKVParallelLinearWithLoRA):
@classmethod
@_not_fully_sharded_can_replace
def can_replace_layer(cls, source_layer: nn.Module,
lora_config: LoRAConfig, packed_modules_list: list,
model_config: Optional[PretrainedConfig]) -> bool:
return type(source_layer) is AscendQKVParallelLinear and len(
packed_modules_list) == 1
class AscendMergedQKVParallelLinearWithLoRA(MergedQKVParallelLinearWithLoRA):
@classmethod
@_not_fully_sharded_can_replace
def can_replace_layer(
@@ -93,18 +77,28 @@ class AscendMergedQKVParallelLinearWithLoRA(MergedQKVParallelLinearWithLoRA):
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: Optional[PretrainedConfig],
model_config: PretrainedConfig | None,
) -> bool:
return (type(source_layer) is AscendQKVParallelLinear
and len(packed_modules_list) == 3)
return type(source_layer) is AscendQKVParallelLinear and len(packed_modules_list) == 1
class AscendMergedQKVParallelLinearWithLoRA(MergedQKVParallelLinearWithLoRA):
@classmethod
@_not_fully_sharded_can_replace
def can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: PretrainedConfig | None,
) -> bool:
return type(source_layer) is AscendQKVParallelLinear and len(packed_modules_list) == 3
def refresh_all_lora_classes():
vllm.lora.utils._all_lora_classes.add(AscendColumnParallelLinearWithLoRA)
vllm.lora.utils._all_lora_classes.add(
AscendMergedColumnParallelLinearWithLoRA)
vllm.lora.utils._all_lora_classes.add(AscendMergedColumnParallelLinearWithLoRA)
vllm.lora.utils._all_lora_classes.add(AscendRowParallelLinearWithLoRA)
vllm.lora.utils._all_lora_classes.add(AscendVocabParallelEmbeddingWithLoRA)
vllm.lora.utils._all_lora_classes.add(AscendQKVParallelLinearWithLoRA)
vllm.lora.utils._all_lora_classes.add(
AscendMergedQKVParallelLinearWithLoRA)
vllm.lora.utils._all_lora_classes.add(AscendMergedQKVParallelLinearWithLoRA)